Merge pull request #9957 from av8ramit/branch_155393864
Branch 155393864
This commit is contained in:
commit
ffd1ed2df7
@ -226,6 +226,10 @@ filegroup(
|
||||
"//tensorflow/contrib/copy_graph:all_files",
|
||||
"//tensorflow/contrib/crf:all_files",
|
||||
"//tensorflow/contrib/cudnn_rnn:all_files",
|
||||
"//tensorflow/contrib/data:all_files",
|
||||
"//tensorflow/contrib/data/python/framework:all_files",
|
||||
"//tensorflow/contrib/data/python/kernel_tests:all_files",
|
||||
"//tensorflow/contrib/data/python/ops:all_files",
|
||||
"//tensorflow/contrib/distributions:all_files",
|
||||
"//tensorflow/contrib/factorization:all_files",
|
||||
"//tensorflow/contrib/factorization/kernels:all_files",
|
||||
|
@ -738,7 +738,8 @@ tensorflow::string OutputName(const TF_Output& output) {
|
||||
const tensorflow::AttrValue* GetAttrValue(TF_Operation* oper,
|
||||
const char* attr_name,
|
||||
TF_Status* status) {
|
||||
const tensorflow::AttrValue* attr = oper->node.attrs().Find(attr_name);
|
||||
const tensorflow::AttrValue* attr =
|
||||
tensorflow::AttrSlice(oper->node.def()).Find(attr_name);
|
||||
if (attr == nullptr) {
|
||||
status->status =
|
||||
InvalidArgument("Operation has no attr named '", attr_name, "'.");
|
||||
@ -1134,7 +1135,7 @@ const char* TF_OperationOpType(TF_Operation* oper) {
|
||||
}
|
||||
|
||||
const char* TF_OperationDevice(TF_Operation* oper) {
|
||||
return oper->node.requested_device().c_str();
|
||||
return oper->node.def().device().c_str();
|
||||
}
|
||||
|
||||
int TF_OperationNumOutputs(TF_Operation* oper) {
|
||||
@ -1149,8 +1150,8 @@ TF_DataType TF_OperationOutputType(TF_Output oper_out) {
|
||||
int TF_OperationOutputListLength(TF_Operation* oper, const char* arg_name,
|
||||
TF_Status* status) {
|
||||
NameRangeMap name_ranges;
|
||||
status->status =
|
||||
NameRangesForNode(oper->node, oper->node.op_def(), nullptr, &name_ranges);
|
||||
status->status = NameRangesForNode(oper->node.def(), oper->node.op_def(),
|
||||
nullptr, &name_ranges);
|
||||
if (!status->status.ok()) return -1;
|
||||
auto iter = name_ranges.find(arg_name);
|
||||
if (iter == name_ranges.end()) {
|
||||
@ -1171,8 +1172,8 @@ TF_DataType TF_OperationInputType(TF_Input oper_in) {
|
||||
int TF_OperationInputListLength(TF_Operation* oper, const char* arg_name,
|
||||
TF_Status* status) {
|
||||
NameRangeMap name_ranges;
|
||||
status->status =
|
||||
NameRangesForNode(oper->node, oper->node.op_def(), &name_ranges, nullptr);
|
||||
status->status = NameRangesForNode(oper->node.def(), oper->node.op_def(),
|
||||
&name_ranges, nullptr);
|
||||
if (!status->status.ok()) return -1;
|
||||
auto iter = name_ranges.find(arg_name);
|
||||
if (iter == name_ranges.end()) {
|
||||
@ -1410,27 +1411,26 @@ void TF_OperationGetAttrStringList(TF_Operation* oper, const char* attr_name,
|
||||
}
|
||||
}
|
||||
|
||||
#define DEFINE_GETATTR(func, c_type, cpp_type, list_field) \
|
||||
void func(TF_Operation* oper, const char* attr_name, c_type* value, \
|
||||
TF_Status* status) { \
|
||||
cpp_type v; \
|
||||
status->status = \
|
||||
tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &v); \
|
||||
*value = static_cast<c_type>(v); \
|
||||
} \
|
||||
void func##List(TF_Operation* oper, const char* attr_name, c_type* values, \
|
||||
int max_values, TF_Status* status) { \
|
||||
const auto* attr = GetAttrValue(oper, attr_name, status); \
|
||||
if (!status->status.ok()) return; \
|
||||
if (attr->value_case() != tensorflow::AttrValue::kList) { \
|
||||
status->status = \
|
||||
InvalidArgument("Value for '", attr_name, "' is not a list."); \
|
||||
return; \
|
||||
} \
|
||||
const auto len = std::min(max_values, attr->list().list_field##_size()); \
|
||||
for (int i = 0; i < len; ++i) { \
|
||||
values[i] = static_cast<c_type>(attr->list().list_field(i)); \
|
||||
} \
|
||||
#define DEFINE_GETATTR(func, c_type, cpp_type, list_field) \
|
||||
void func(TF_Operation* oper, const char* attr_name, c_type* value, \
|
||||
TF_Status* status) { \
|
||||
cpp_type v; \
|
||||
status->status = tensorflow::GetNodeAttr(oper->node.def(), attr_name, &v); \
|
||||
*value = static_cast<c_type>(v); \
|
||||
} \
|
||||
void func##List(TF_Operation* oper, const char* attr_name, c_type* values, \
|
||||
int max_values, TF_Status* status) { \
|
||||
const auto* attr = GetAttrValue(oper, attr_name, status); \
|
||||
if (!status->status.ok()) return; \
|
||||
if (attr->value_case() != tensorflow::AttrValue::kList) { \
|
||||
status->status = \
|
||||
InvalidArgument("Value for '", attr_name, "' is not a list."); \
|
||||
return; \
|
||||
} \
|
||||
const auto len = std::min(max_values, attr->list().list_field##_size()); \
|
||||
for (int i = 0; i < len; ++i) { \
|
||||
values[i] = static_cast<c_type>(attr->list().list_field(i)); \
|
||||
} \
|
||||
}
|
||||
DEFINE_GETATTR(TF_OperationGetAttrInt, int64_t, tensorflow::int64, i);
|
||||
DEFINE_GETATTR(TF_OperationGetAttrFloat, float, float, f);
|
||||
@ -1441,8 +1441,7 @@ DEFINE_GETATTR(TF_OperationGetAttrType, TF_DataType, DataType, type);
|
||||
void TF_OperationGetAttrShape(TF_Operation* oper, const char* attr_name,
|
||||
int64_t* value, int num_dims, TF_Status* status) {
|
||||
PartialTensorShape shape;
|
||||
status->status =
|
||||
tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &shape);
|
||||
status->status = tensorflow::GetNodeAttr(oper->node.def(), attr_name, &shape);
|
||||
if (!status->status.ok()) return;
|
||||
auto len = std::min(shape.dims(), num_dims);
|
||||
for (int i = 0; i < len; ++i) {
|
||||
@ -1456,7 +1455,7 @@ void TF_OperationGetAttrShapeList(TF_Operation* oper, const char* attr_name,
|
||||
int storage_size, TF_Status* status) {
|
||||
std::vector<PartialTensorShape> shapes;
|
||||
status->status =
|
||||
tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &shapes);
|
||||
tensorflow::GetNodeAttr(oper->node.def(), attr_name, &shapes);
|
||||
if (!status->status.ok()) return;
|
||||
auto len = std::min(static_cast<int>(shapes.size()), max_values);
|
||||
int64_t* p = storage;
|
||||
@ -1523,7 +1522,7 @@ void TF_OperationGetAttrTensor(TF_Operation* oper, const char* attr_name,
|
||||
TF_Tensor** value, TF_Status* status) {
|
||||
*value = nullptr;
|
||||
Tensor t;
|
||||
status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &t);
|
||||
status->status = tensorflow::GetNodeAttr(oper->node.def(), attr_name, &t);
|
||||
if (!status->status.ok()) return;
|
||||
*value = new TF_Tensor{static_cast<TF_DataType>(t.dtype()), t.shape(),
|
||||
tensorflow::TensorCApi::Buffer(t)};
|
||||
@ -1534,7 +1533,7 @@ void TF_OperationGetAttrTensorList(TF_Operation* oper, const char* attr_name,
|
||||
TF_Tensor** values, int max_values,
|
||||
TF_Status* status) {
|
||||
std::vector<Tensor> ts;
|
||||
status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &ts);
|
||||
status->status = tensorflow::GetNodeAttr(oper->node.def(), attr_name, &ts);
|
||||
if (!status->status.ok()) return;
|
||||
const auto len = std::min(max_values, static_cast<int>(ts.size()));
|
||||
for (int i = 0; i < len; ++i) {
|
||||
|
@ -740,10 +740,11 @@ void OpInfo::GetOutput(string* out) const {
|
||||
return;
|
||||
}
|
||||
strings::StrAppend(out, " ::tensorflow::NameRangeMap _outputs_range;\n");
|
||||
strings::StrAppend(out,
|
||||
" ::tensorflow::Status _status_ = "
|
||||
"::tensorflow::NameRangesForNode(*ret, ret->op_def(), "
|
||||
"nullptr, &_outputs_range);\n");
|
||||
strings::StrAppend(
|
||||
out,
|
||||
" ::tensorflow::Status _status_ = "
|
||||
"::tensorflow::NameRangesForNode(ret->def(), ret->op_def(), "
|
||||
"nullptr, &_outputs_range);\n");
|
||||
strings::StrAppend(out, " if (!_status_.ok()) {\n", " ", scope_str,
|
||||
".UpdateStatus(_status_);\n", " return;\n");
|
||||
strings::StrAppend(out, " }\n\n");
|
||||
|
@ -35,8 +35,8 @@ Output Linear(const Scope& scope, Input x, Input w, Input b) {
|
||||
void GetColocationConstraints(const Output& tensor,
|
||||
std::vector<string>* constraints) {
|
||||
constraints->clear();
|
||||
TF_EXPECT_OK(GetNodeAttr(tensor.op().node()->attrs(), kColocationAttrName,
|
||||
constraints));
|
||||
TF_EXPECT_OK(
|
||||
GetNodeAttr(tensor.op().node()->def(), kColocationAttrName, constraints));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
@ -159,11 +159,11 @@ TEST(CCOpTest, KernelLabel) {
|
||||
Scope root = Scope::NewRootScope();
|
||||
auto add = Add(root.WithKernelLabel("AddWithKernelLabel"), 1.0f, 2.0f);
|
||||
TF_EXPECT_OK(root.status());
|
||||
AttrSlice attrs = add.z.op().node()->attrs();
|
||||
const auto* kernel_attr = attrs.Find("_kernel");
|
||||
ASSERT_TRUE(kernel_attr);
|
||||
TF_EXPECT_OK(AttrValueHasType(*kernel_attr, "string"));
|
||||
EXPECT_EQ(kernel_attr->s(), "AddWithKernelLabel");
|
||||
const auto& attrs = add.z.op().node()->def().attr();
|
||||
ASSERT_TRUE(attrs.find("_kernel") != attrs.end());
|
||||
auto kernel_attr = attrs.find("_kernel")->second;
|
||||
TF_EXPECT_OK(AttrValueHasType(kernel_attr, "string"));
|
||||
EXPECT_EQ(kernel_attr.s(), "AddWithKernelLabel");
|
||||
}
|
||||
|
||||
TEST(CCOpTest, ColocateWith) {
|
||||
@ -190,7 +190,8 @@ TEST(CCOpTest, ColocateWith) {
|
||||
|
||||
Scope with_colocate = root.ColocateWith(c3).ColocateWith(c4);
|
||||
auto c6 = Const(with_colocate.WithOpName("c6").ClearColocation(), 7);
|
||||
EXPECT_FALSE(c6.op().node()->attrs().Find("_class"));
|
||||
const auto& attrs = c6.op().node()->def().attr();
|
||||
EXPECT_TRUE(attrs.find("_class") == attrs.end());
|
||||
}
|
||||
|
||||
TEST(CCOpTest, TemplatedConst) {
|
||||
|
@ -271,9 +271,9 @@ Scope::Impl::Impl(const Scope& other, Tags::Colocate,
|
||||
std::unordered_set<string> Scope::Impl::GetColocationConstraints(
|
||||
const Operation& colocate_with_op) const {
|
||||
std::unordered_set<string> current_constraints(colocation_constraints_);
|
||||
const AttrSlice attrs = colocate_with_op.node()->attrs();
|
||||
const NodeDef& node_def = colocate_with_op.node()->def();
|
||||
std::vector<string> node_constraints;
|
||||
if (GetNodeAttr(attrs, kColocationAttrName, &node_constraints).ok()) {
|
||||
if (GetNodeAttr(node_def, kColocationAttrName, &node_constraints).ok()) {
|
||||
for (const string& entry : node_constraints) {
|
||||
StringPiece s(entry);
|
||||
if (s.Consume(kColocationGroupPrefix)) {
|
||||
|
@ -43,9 +43,9 @@ Status PackGrad(const Scope& scope, const Operation& op,
|
||||
const std::vector<Output>& grad_inputs,
|
||||
std::vector<Output>* grad_outputs) {
|
||||
int N;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "N", &N));
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "N", &N));
|
||||
int axis;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "axis", &axis));
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "axis", &axis));
|
||||
|
||||
grad_outputs->reserve(N);
|
||||
auto grad_op = Unstack(scope, grad_inputs[0], N, Unstack::Axis(axis));
|
||||
@ -60,7 +60,7 @@ Status UnpackGrad(const Scope& scope, const Operation& op,
|
||||
const std::vector<Output>& grad_inputs,
|
||||
std::vector<Output>* grad_outputs) {
|
||||
int axis;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "axis", &axis));
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "axis", &axis));
|
||||
grad_outputs->push_back(Stack(scope, grad_inputs, Stack::Axis(axis)));
|
||||
return scope.status();
|
||||
}
|
||||
@ -162,7 +162,7 @@ Status CheckNumericsGrad(const Scope& scope, const Operation& op,
|
||||
const std::vector<Output>& grad_inputs,
|
||||
std::vector<Output>* grad_outputs) {
|
||||
string message;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "message", &message));
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "message", &message));
|
||||
string err_msg = strings::StrCat(
|
||||
"Not a number (NaN) or infinity (Inf) values detected in gradient. ",
|
||||
message);
|
||||
@ -215,9 +215,9 @@ Status ReverseSequenceGrad(const Scope& scope, const Operation& op,
|
||||
std::vector<Output>* grad_outputs) {
|
||||
auto seq_lengths = op.input(1);
|
||||
int batch_dim;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "batch_dim", &batch_dim));
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "batch_dim", &batch_dim));
|
||||
int seq_dim;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "seq_dim", &seq_dim));
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "seq_dim", &seq_dim));
|
||||
grad_outputs->push_back(
|
||||
ReverseSequence(scope, grad_inputs[0], seq_lengths, seq_dim,
|
||||
ReverseSequence::BatchDim(batch_dim)));
|
||||
@ -267,8 +267,7 @@ Status SpaceToBatchGrad(const Scope& scope, const Operation& op,
|
||||
const std::vector<Output>& grad_inputs,
|
||||
std::vector<Output>* grad_outputs) {
|
||||
int block_size;
|
||||
TF_RETURN_IF_ERROR(
|
||||
GetNodeAttr(op.node()->attrs(), "block_size", &block_size));
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "block_size", &block_size));
|
||||
grad_outputs->push_back(
|
||||
BatchToSpace(scope, grad_inputs[0], op.input(1), block_size));
|
||||
grad_outputs->push_back(NoGradient());
|
||||
@ -291,8 +290,7 @@ Status BatchToSpaceGrad(const Scope& scope, const Operation& op,
|
||||
const std::vector<Output>& grad_inputs,
|
||||
std::vector<Output>* grad_outputs) {
|
||||
int block_size;
|
||||
TF_RETURN_IF_ERROR(
|
||||
GetNodeAttr(op.node()->attrs(), "block_size", &block_size));
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "block_size", &block_size));
|
||||
grad_outputs->push_back(
|
||||
SpaceToBatch(scope, grad_inputs[0], op.input(1), block_size));
|
||||
grad_outputs->push_back(NoGradient());
|
||||
@ -315,8 +313,7 @@ Status SpaceToDepthGrad(const Scope& scope, const Operation& op,
|
||||
const std::vector<Output>& grad_inputs,
|
||||
std::vector<Output>* grad_outputs) {
|
||||
int block_size;
|
||||
TF_RETURN_IF_ERROR(
|
||||
GetNodeAttr(op.node()->attrs(), "block_size", &block_size));
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "block_size", &block_size));
|
||||
grad_outputs->push_back(DepthToSpace(scope, grad_inputs[0], block_size));
|
||||
return scope.status();
|
||||
}
|
||||
@ -326,8 +323,7 @@ Status DepthToSpaceGrad(const Scope& scope, const Operation& op,
|
||||
const std::vector<Output>& grad_inputs,
|
||||
std::vector<Output>* grad_outputs) {
|
||||
int block_size;
|
||||
TF_RETURN_IF_ERROR(
|
||||
GetNodeAttr(op.node()->attrs(), "block_size", &block_size));
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "block_size", &block_size));
|
||||
grad_outputs->push_back(SpaceToDepth(scope, grad_inputs[0], block_size));
|
||||
return scope.status();
|
||||
}
|
||||
@ -337,7 +333,7 @@ Status MirrorPadGrad(const Scope& scope, const Operation& op,
|
||||
const std::vector<Output>& grad_inputs,
|
||||
std::vector<Output>* grad_outputs) {
|
||||
string mode;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "mode", &mode));
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "mode", &mode));
|
||||
grad_outputs->push_back(tensorflow::ops::internal::MirrorPadGrad(
|
||||
scope, grad_inputs[0], op.input(1), mode));
|
||||
grad_outputs->push_back(NoGradient());
|
||||
@ -350,7 +346,7 @@ Status MirrorPadGradGrad(const Scope& scope, const Operation& op,
|
||||
const std::vector<Output>& grad_inputs,
|
||||
std::vector<Output>* grad_outputs) {
|
||||
string mode;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "mode", &mode));
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "mode", &mode));
|
||||
grad_outputs->push_back(MirrorPad(scope, grad_inputs[0], op.input(1), mode));
|
||||
grad_outputs->push_back(NoGradient());
|
||||
return scope.status();
|
||||
|
@ -350,7 +350,7 @@ Status MatMulGradCommon(const Scope& scope, const Operation& op,
|
||||
const string& attr_adj_x, const string& attr_adj_y,
|
||||
std::vector<Output>* grad_outputs) {
|
||||
DataType dtype;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->attrs(), "T", &dtype));
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->def(), "T", &dtype));
|
||||
if (dtype == DT_COMPLEX64 || dtype == DT_COMPLEX128) {
|
||||
return errors::Unimplemented(
|
||||
"MatMul gradient for complex data type is not supported yet.");
|
||||
@ -358,10 +358,8 @@ Status MatMulGradCommon(const Scope& scope, const Operation& op,
|
||||
|
||||
bool ta;
|
||||
bool tb;
|
||||
TF_RETURN_IF_ERROR(
|
||||
GetNodeAttr(op.output(0).node()->attrs(), attr_adj_x, &ta));
|
||||
TF_RETURN_IF_ERROR(
|
||||
GetNodeAttr(op.output(0).node()->attrs(), attr_adj_y, &tb));
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->def(), attr_adj_x, &ta));
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->def(), attr_adj_y, &tb));
|
||||
|
||||
if (!ta && !tb) {
|
||||
return MatMulGradHelper(scope, is_batch, grad_inputs[0], false, op.input(1),
|
||||
|
@ -28,9 +28,9 @@ void ExpectNodeEqual(const Node* n, gtl::ArraySlice<T> values,
|
||||
TensorShape shape) {
|
||||
EXPECT_TRUE(n->IsConstant());
|
||||
Tensor tensor;
|
||||
TF_EXPECT_OK(GetNodeAttr(n->attrs(), "value", &tensor));
|
||||
TF_EXPECT_OK(GetNodeAttr(n->def(), "value", &tensor));
|
||||
DataType dtype;
|
||||
TF_EXPECT_OK(GetNodeAttr(n->attrs(), "dtype", &dtype));
|
||||
TF_EXPECT_OK(GetNodeAttr(n->def(), "dtype", &dtype));
|
||||
EXPECT_EQ(tensor.dtype(), dtype);
|
||||
test::ExpectTensorEqual<T>(tensor, test::AsTensor(values, shape));
|
||||
}
|
||||
@ -39,9 +39,9 @@ void ExpectTypeAndShape(const Node* n, DataType expected_dtype,
|
||||
TensorShape expected_shape) {
|
||||
EXPECT_TRUE(n->IsConstant());
|
||||
Tensor tensor;
|
||||
TF_EXPECT_OK(GetNodeAttr(n->attrs(), "value", &tensor));
|
||||
TF_EXPECT_OK(GetNodeAttr(n->def(), "value", &tensor));
|
||||
DataType dtype;
|
||||
TF_EXPECT_OK(GetNodeAttr(n->attrs(), "dtype", &dtype));
|
||||
TF_EXPECT_OK(GetNodeAttr(n->def(), "dtype", &dtype));
|
||||
EXPECT_EQ(dtype, expected_dtype);
|
||||
EXPECT_EQ(expected_shape, TensorShape(tensor.shape()));
|
||||
}
|
||||
|
@ -203,14 +203,14 @@ Status RewriteAndPruneGraph(Graph* graph, const Config& config,
|
||||
for (const Node* n : graph->nodes()) {
|
||||
if (n->type_string() == kArgOp) {
|
||||
string feed_id;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), kFeedIdAttr, &feed_id));
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), kFeedIdAttr, &feed_id));
|
||||
if (missing_feeds.erase(feed_id) == 0) {
|
||||
return errors::Aborted(kArgOp,
|
||||
" node found with unknown feed id: ", feed_id);
|
||||
}
|
||||
} else if (n->type_string() == kRetvalOp) {
|
||||
string fetch_id;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), kFetchIdAttr, &fetch_id));
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), kFetchIdAttr, &fetch_id));
|
||||
if (missing_fetches.erase(fetch_id) == 0) {
|
||||
return errors::Aborted(kRetvalOp,
|
||||
" node found with unknown fetch id: ", fetch_id);
|
||||
@ -234,7 +234,7 @@ Status CollectArgNodes(const Graph& graph, std::vector<Node*>* arg_nodes) {
|
||||
for (Node* n : graph.nodes()) {
|
||||
if (n->type_string() == kArgOp) {
|
||||
int index;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "index", &index));
|
||||
auto insert_result = indexed_arg_nodes.insert({index, n});
|
||||
if (!insert_result.second) {
|
||||
const Node* dup = insert_result.first->second;
|
||||
@ -264,9 +264,9 @@ Status CreateXlaArgs(const Graph& graph,
|
||||
for (const Node* node : arg_nodes) {
|
||||
XlaCompiler::Argument arg;
|
||||
arg.kind = XlaCompiler::Argument::kParameter;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "T", &arg.type));
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kShapeAttr, &arg.shape));
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kDebugNameAttr, &arg.name));
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), "T", &arg.type));
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), kShapeAttr, &arg.shape));
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), kDebugNameAttr, &arg.name));
|
||||
xla_args->push_back(arg);
|
||||
}
|
||||
return Status::OK();
|
||||
|
@ -66,9 +66,9 @@ static Status ReplaceNodeWithXlaLaunch(Graph* graph, Node* node) {
|
||||
|
||||
int num_constant_args, num_resource_args;
|
||||
TF_RETURN_IF_ERROR(
|
||||
GetNodeAttr(node->attrs(), kXlaNumConstantArgsAttr, &num_constant_args));
|
||||
GetNodeAttr(node->def(), kXlaNumConstantArgsAttr, &num_constant_args));
|
||||
TF_RETURN_IF_ERROR(
|
||||
GetNodeAttr(node->attrs(), kXlaNumResourceArgsAttr, &num_resource_args));
|
||||
GetNodeAttr(node->def(), kXlaNumResourceArgsAttr, &num_resource_args));
|
||||
|
||||
if (num_constant_args < 0 || num_resource_args < 0 ||
|
||||
num_constant_args + num_resource_args > node->num_inputs()) {
|
||||
@ -88,7 +88,7 @@ static Status ReplaceNodeWithXlaLaunch(Graph* graph, Node* node) {
|
||||
Node* launch_node;
|
||||
TF_RETURN_IF_ERROR(BuildLaunchNode(
|
||||
graph->NewName(node->name()), node->type_string(), node->def().attr(),
|
||||
node->requested_device(), const_dtypes, num_resource_args, arg_dtypes,
|
||||
node->def().device(), const_dtypes, num_resource_args, arg_dtypes,
|
||||
node->output_types(), graph, &launch_node));
|
||||
launch_node->set_assigned_device_name(node->assigned_device_name());
|
||||
|
||||
@ -173,8 +173,7 @@ Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& ndef,
|
||||
FunctionLibraryRuntime::Handle handle;
|
||||
// If ndef is not instantiable, e.g., the function does not exist,
|
||||
// simply bail out.
|
||||
TF_RETURN_IF_ERROR(
|
||||
flr->Instantiate(ndef.op(), AttrSlice(&ndef.attr()), &handle));
|
||||
TF_RETURN_IF_ERROR(flr->Instantiate(ndef.op(), ndef.attr(), &handle));
|
||||
const FunctionBody* fbody = flr->GetFunctionBody(handle);
|
||||
CHECK(fbody); // Can't be nullptr since we just instantiated it.
|
||||
std::vector<bool> const_args(fbody->arg_types.size());
|
||||
|
@ -165,7 +165,7 @@ static const char* const kRetValOp = "_Retval";
|
||||
// none.
|
||||
string Encapsulator::GetFunctionNameAttr(Node const* node) const {
|
||||
string attr;
|
||||
if (!GetNodeAttr(node->attrs(), group_attribute_, &attr).ok()) {
|
||||
if (!GetNodeAttr(node->def(), group_attribute_, &attr).ok()) {
|
||||
attr.clear();
|
||||
}
|
||||
return attr;
|
||||
@ -195,7 +195,7 @@ Status Encapsulator::SplitIntoSubgraphs() {
|
||||
|
||||
// Check the device matches any existing device.
|
||||
string device = node->assigned_device_name().empty()
|
||||
? node->requested_device()
|
||||
? node->def().device()
|
||||
: node->assigned_device_name();
|
||||
|
||||
if (subgraph.device.empty()) {
|
||||
@ -593,7 +593,7 @@ static Status GetArgTypes(const Graph& graph, DataTypeVector* types) {
|
||||
for (Node* n : graph.nodes()) {
|
||||
if (n->type_string() == kArgOp) {
|
||||
int index;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "index", &index));
|
||||
if (index < 0 || index >= types->size()) {
|
||||
return errors::InvalidArgument("Invalid argument number");
|
||||
}
|
||||
@ -610,7 +610,7 @@ static Status RenumberArguments(Graph* graph,
|
||||
for (Node* n : graph->nodes()) {
|
||||
if (n->type_string() == kArgOp) {
|
||||
int index;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "index", &index));
|
||||
if (index < 0 || index >= permutation.size()) {
|
||||
return errors::InvalidArgument("Invalid argument number");
|
||||
}
|
||||
@ -713,7 +713,7 @@ Status EncapsulateSubgraphsPass::Run(
|
||||
bool IsXlaCompiledKernel(const Node& node) {
|
||||
bool is_compiled = false;
|
||||
bool has_compilation_attr =
|
||||
GetNodeAttr(node.attrs(), kXlaCompiledKernelAttr, &is_compiled).ok() &&
|
||||
GetNodeAttr(node.def(), kXlaCompiledKernelAttr, &is_compiled).ok() &&
|
||||
is_compiled;
|
||||
return has_compilation_attr ? is_compiled : false;
|
||||
}
|
||||
|
@ -126,8 +126,8 @@ Status GraphToFunctionDef(const Graph& graph, const string& name,
|
||||
if (node->type_string() == kArgOp) {
|
||||
int index;
|
||||
DataType type;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "T", &type));
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "index", &index));
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), "T", &type));
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), "index", &index));
|
||||
while (fdef->signature().input_arg_size() <= index) {
|
||||
fdef->mutable_signature()->add_input_arg();
|
||||
}
|
||||
@ -143,8 +143,8 @@ Status GraphToFunctionDef(const Graph& graph, const string& name,
|
||||
if (node->type_string() == kRetValOp) {
|
||||
int index;
|
||||
DataType type;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "T", &type));
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "index", &index));
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), "T", &type));
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), "index", &index));
|
||||
while (fdef->signature().output_arg_size() <= index) {
|
||||
fdef->mutable_signature()->add_output_arg();
|
||||
}
|
||||
@ -161,7 +161,7 @@ Status GraphToFunctionDef(const Graph& graph, const string& name,
|
||||
}
|
||||
|
||||
NodeDef* node_def = fdef->add_node_def();
|
||||
*node_def = node->def();
|
||||
node_def->CopyFrom(node->def());
|
||||
node_def->set_name(node_names.Uniquify(node->name()));
|
||||
|
||||
// Reset input names based on graph rather than the NodeDef.
|
||||
@ -203,8 +203,8 @@ Status GraphToFunctionDef(const Graph& graph, const string& name,
|
||||
|
||||
// Populate tensor_renaming.
|
||||
NameRangeMap output_ranges;
|
||||
TF_RETURN_IF_ERROR(
|
||||
NameRangesForNode(*node, node->op_def(), nullptr, &output_ranges));
|
||||
TF_RETURN_IF_ERROR(NameRangesForNode(node->def(), node->op_def(), nullptr,
|
||||
&output_ranges));
|
||||
for (const auto& output : output_ranges) {
|
||||
for (int i = output.second.first; i < output.second.second; ++i) {
|
||||
const string tensor_name = strings::StrCat(
|
||||
|
@ -55,7 +55,7 @@ XlaDeviceLaunchOp::XlaDeviceLaunchOp(OpKernelConstruction* ctx)
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("function", &func));
|
||||
function_ = *func;
|
||||
VLOG(1) << "XlaDeviceLaunch created function="
|
||||
<< Canonicalize(function_.name(), AttrSlice(&function_.attr()));
|
||||
<< Canonicalize(function_.name(), function_.attr());
|
||||
DataTypeVector constant_types;
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("Tconstants", &constant_types));
|
||||
num_constant_args_ = constant_types.size();
|
||||
@ -81,7 +81,7 @@ std::vector<OptionalTensor> SnapshotResourceVariables(OpKernelContext* ctx,
|
||||
|
||||
void XlaDeviceLaunchOp::Compute(OpKernelContext* ctx) {
|
||||
VLOG(1) << "XlaDeviceLaunch::Compute "
|
||||
<< Canonicalize(function_.name(), AttrSlice(&function_.attr()));
|
||||
<< Canonicalize(function_.name(), function_.attr());
|
||||
// We store information about the JIT-compiled XLA computation
|
||||
// in the ResourceMgr.
|
||||
ResourceMgr* rm = ctx->resource_manager();
|
||||
|
@ -186,7 +186,7 @@ Status XlaLocalLaunchOp::BuildCompilationCache(OpKernelContext* ctx,
|
||||
|
||||
void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
|
||||
VLOG(1) << "XlaLocalLaunchOp::Compute "
|
||||
<< Canonicalize(function_.name(), AttrSlice(&function_.attr()));
|
||||
<< Canonicalize(function_.name(), function_.attr());
|
||||
// We store information about the JIT-compiled XLA computation
|
||||
// in the ResourceMgr.
|
||||
ResourceMgr* rm = ctx->resource_manager();
|
||||
|
@ -56,18 +56,18 @@ bool IsCompilableCall(const NodeDef& call_def,
|
||||
const DeviceType& jit_device_type, int depth,
|
||||
FunctionLibraryRuntime* lib_runtime);
|
||||
|
||||
// Tests whether 'while_node' is a completely compilable loop.
|
||||
// Tests whether 'while_def' is a completely compilable loop.
|
||||
// Every operator in the condition and body functions must be compilable for a
|
||||
// while loop to be compilable.
|
||||
bool IsCompilableWhile(const Node& while_node,
|
||||
bool IsCompilableWhile(const NodeDef& while_def,
|
||||
const DeviceType& jit_device_type, int depth,
|
||||
FunctionLibraryRuntime* lib_runtime) {
|
||||
VLOG(2) << "Loop marking: " << while_node.type_string();
|
||||
VLOG(2) << "Loop marking: " << while_def.op();
|
||||
|
||||
const NameAttrList* name_attr;
|
||||
NodeDef call;
|
||||
Status status;
|
||||
status = GetNodeAttr(while_node.attrs(), "cond", &name_attr);
|
||||
status = GetNodeAttr(while_def, "cond", &name_attr);
|
||||
if (!status.ok()) {
|
||||
VLOG(2) << "Missing 'cond' attribute on While node.";
|
||||
return false;
|
||||
@ -80,7 +80,7 @@ bool IsCompilableWhile(const Node& while_node,
|
||||
VLOG(2) << "Can't compile loop condition: " << cond_func;
|
||||
return false;
|
||||
}
|
||||
status = GetNodeAttr(while_node.attrs(), "body", &name_attr);
|
||||
status = GetNodeAttr(while_def, "body", &name_attr);
|
||||
if (!status.ok()) {
|
||||
VLOG(2) << "Missing 'body' attribute on While node.";
|
||||
return false;
|
||||
@ -112,7 +112,7 @@ bool IsCompilableCall(const NodeDef& call_def,
|
||||
|
||||
FunctionLibraryRuntime::Handle handle;
|
||||
Status status =
|
||||
lib_runtime->Instantiate(call_def.op(), AttrSlice(call_def), &handle);
|
||||
lib_runtime->Instantiate(call_def.op(), call_def.attr(), &handle);
|
||||
if (!status.ok()) {
|
||||
VLOG(2) << "Could not instantiate " << call_def.op() << ": " << status;
|
||||
return false;
|
||||
@ -134,11 +134,11 @@ bool IsCompilableCall(const NodeDef& call_def,
|
||||
|
||||
for (Node* node : fbody->graph->nodes()) {
|
||||
if (node->IsSource() || node->IsSink()) continue;
|
||||
if (node->type_string() == "_Arg" || node->type_string() == "_Retval")
|
||||
continue;
|
||||
if (node->type_string() == "While") {
|
||||
if (node->def().op() == "_Arg" || node->def().op() == "_Retval") continue;
|
||||
if (node->def().op() == "While") {
|
||||
// Handle functional While loop (not in open source build).
|
||||
return IsCompilableWhile(*node, jit_device_type, depth + 1, lib_runtime);
|
||||
return IsCompilableWhile(node->def(), jit_device_type, depth + 1,
|
||||
lib_runtime);
|
||||
}
|
||||
if (!HasXLAKernel(*node, jit_device_type) &&
|
||||
!IsCompilableCall(node->def(), jit_device_type, depth + 1,
|
||||
@ -192,16 +192,17 @@ Status FindCompilationCandidates(
|
||||
if (!HasXLAKernel(*node, jit_device_type) &&
|
||||
!IsCompilableCall(node->def(), jit_device_type, 0, lib_runtime.get())) {
|
||||
VLOG(2) << "Compilation rejected node: unsupported op " << node->name()
|
||||
<< ": " << node->type_string();
|
||||
<< ": " << node->def().op();
|
||||
continue;
|
||||
}
|
||||
if (!registration->compile_resource_ops && HasResourceArgument(*node)) {
|
||||
VLOG(2) << "Compilation rejected node: resource argument " << node->name()
|
||||
<< ": " << node->type_string();
|
||||
<< ": " << node->def().op();
|
||||
continue;
|
||||
}
|
||||
if (node->type_string() == "While" &&
|
||||
!IsCompilableWhile(*node, jit_device_type, 0, lib_runtime.get())) {
|
||||
if (node->def().op() == "While" &&
|
||||
!IsCompilableWhile(node->def(), jit_device_type, 0,
|
||||
lib_runtime.get())) {
|
||||
continue;
|
||||
}
|
||||
candidates->insert(node);
|
||||
@ -318,10 +319,10 @@ Status MarkForCompilationPass::Run(
|
||||
|
||||
// If there is a _XlaCompile annotation, use its value.
|
||||
bool compile = false;
|
||||
Status status = GetNodeAttr(node->attrs(), kXlaCompileAttr, &compile);
|
||||
Status status = GetNodeAttr(node->def(), kXlaCompileAttr, &compile);
|
||||
if (status.ok()) return compile;
|
||||
|
||||
status = fld->GetAttr(*node, kXlaCompileAttr, &compile);
|
||||
status = fld->GetAttr(node->def(), kXlaCompileAttr, &compile);
|
||||
if (status.ok()) return compile;
|
||||
|
||||
// Otherwise use the value of global_jit_level.
|
||||
@ -484,8 +485,8 @@ Status MarkForCompilationPass::RunImpl(
|
||||
// all nodes marked with _XlaCompile=true to also have a
|
||||
// _XlaScope property set (and raise an error otherwise); but
|
||||
// for now we don't do this.
|
||||
if (GetNodeAttr(node_from->attrs(), kXlaScopeAttr, &from_scope).ok() &&
|
||||
GetNodeAttr(node_to->attrs(), kXlaScopeAttr, &to_scope).ok() &&
|
||||
if (GetNodeAttr(node_from->def(), kXlaScopeAttr, &from_scope).ok() &&
|
||||
GetNodeAttr(node_to->def(), kXlaScopeAttr, &to_scope).ok() &&
|
||||
from_scope != to_scope) {
|
||||
continue;
|
||||
}
|
||||
@ -540,9 +541,10 @@ Status MarkForCompilationPass::RunImpl(
|
||||
// Compile if the user marked this node _XlaCompile=true
|
||||
bool compile_attr = false;
|
||||
bool marked_for_compilation = false;
|
||||
if (GetNodeAttr(n->attrs(), kXlaCompileAttr, &compile_attr).ok()) {
|
||||
if (GetNodeAttr(n->def(), kXlaCompileAttr, &compile_attr).ok()) {
|
||||
marked_for_compilation = compile_attr;
|
||||
} else if (options.flib_def->GetAttr(*n, kXlaCompileAttr, &compile_attr)
|
||||
} else if (options.flib_def
|
||||
->GetAttr(n->def(), kXlaCompileAttr, &compile_attr)
|
||||
.ok()) {
|
||||
marked_for_compilation = compile_attr;
|
||||
}
|
||||
|
@ -57,7 +57,7 @@ std::unordered_map<string, string> GetClusters(const Graph& graph) {
|
||||
std::unordered_map<string, string> ids;
|
||||
for (Node* node : graph.nodes()) {
|
||||
string cluster;
|
||||
if (GetNodeAttr(node->attrs(), kXlaClusterAttr, &cluster).ok()) {
|
||||
if (GetNodeAttr(node->def(), kXlaClusterAttr, &cluster).ok()) {
|
||||
CHECK(!cluster.empty());
|
||||
ids[node->name()] = cluster;
|
||||
}
|
||||
|
@ -95,7 +95,7 @@ Status XlaCompilationCache::BuildSignature(
|
||||
const NameAttrList& function, int num_constant_args,
|
||||
const std::vector<OptionalTensor>& variable_args, OpKernelContext* ctx,
|
||||
Signature* signature) {
|
||||
signature->name = Canonicalize(function.name(), AttrSlice(&function.attr()));
|
||||
signature->name = Canonicalize(function.name(), function.attr());
|
||||
signature->arg_values.resize(num_constant_args);
|
||||
|
||||
signature->arg_types.reserve(ctx->num_inputs() - num_constant_args);
|
||||
|
@ -108,7 +108,7 @@ Status BackwardsConstAnalysis(const Graph& g,
|
||||
if (must_be_const.find(node) != must_be_const.end()) {
|
||||
if (node->type_string() == "_Arg") {
|
||||
int index;
|
||||
status = GetNodeAttr(node->attrs(), "index", &index);
|
||||
status = GetNodeAttr(node->def(), "index", &index);
|
||||
if (!status.ok()) return;
|
||||
compile_time_const_args->at(index) = true;
|
||||
return;
|
||||
@ -124,8 +124,8 @@ Status BackwardsConstAnalysis(const Graph& g,
|
||||
if (range.first == range.second) return;
|
||||
|
||||
NameRangeMap input_name_ranges;
|
||||
status =
|
||||
NameRangesForNode(*node, node->op_def(), &input_name_ranges, nullptr);
|
||||
status = NameRangesForNode(node->def(), node->op_def(), &input_name_ranges,
|
||||
nullptr);
|
||||
if (!status.ok()) return;
|
||||
|
||||
for (auto it = range.first; it != range.second; ++it) {
|
||||
|
@ -68,8 +68,7 @@ class SymbolicGradientOp : public AsyncOpKernel {
|
||||
done);
|
||||
|
||||
OP_REQUIRES_OK_ASYNC(
|
||||
ctx, lib->Instantiate(kGradientOp, AttrSlice(&def().attr()), &handle_),
|
||||
done);
|
||||
ctx, lib->Instantiate(kGradientOp, def().attr(), &handle_), done);
|
||||
|
||||
FunctionLibraryRuntime::Options opts;
|
||||
opts.step_id = ctx->step_id();
|
||||
|
@ -106,8 +106,7 @@ Status XlaCompiler::CompileFunction(
|
||||
const XlaCompiler::CompileOptions& options, const NameAttrList& function,
|
||||
const std::vector<XlaCompiler::Argument>& args,
|
||||
XlaCompiler::CompilationResult* result) {
|
||||
const string function_id =
|
||||
Canonicalize(function.name(), AttrSlice(&function.attr()));
|
||||
const string function_id = Canonicalize(function.name(), function.attr());
|
||||
VLOG(1) << "XlaCompiler::CompileFunction " << function_id;
|
||||
|
||||
auto it = cache_.find({function_id, args});
|
||||
@ -117,8 +116,8 @@ Status XlaCompiler::CompileFunction(
|
||||
}
|
||||
|
||||
FunctionLibraryRuntime::Handle handle;
|
||||
TF_RETURN_IF_ERROR(flib_runtime_->Instantiate(
|
||||
function.name(), AttrSlice(&function.attr()), &handle));
|
||||
TF_RETURN_IF_ERROR(
|
||||
flib_runtime_->Instantiate(function.name(), function.attr(), &handle));
|
||||
|
||||
const FunctionBody* fbody = flib_runtime_->GetFunctionBody(handle);
|
||||
CHECK(fbody);
|
||||
|
@ -1,4 +1,3 @@
|
||||
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@ -168,6 +167,8 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
|
||||
Status HandleReverse(HloInstruction* reverse,
|
||||
HloInstruction* operand) override;
|
||||
Status HandleSlice(HloInstruction* slice, HloInstruction* operand) override;
|
||||
Status HandleDynamicSlice(HloInstruction* slice, HloInstruction* operand,
|
||||
HloInstruction* start_indices) override;
|
||||
Status HandleDynamicUpdateSlice(HloInstruction* dynamic_update_slice,
|
||||
HloInstruction* operand,
|
||||
HloInstruction* update,
|
||||
@ -1025,6 +1026,15 @@ Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status AlgebraicSimplifierVisitor::HandleDynamicSlice(
|
||||
HloInstruction* dynamic_slice, HloInstruction* operand,
|
||||
HloInstruction* start_indices) {
|
||||
if (ShapeUtil::IsScalar(dynamic_slice->shape())) {
|
||||
return ReplaceInstruction(dynamic_slice, operand);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status AlgebraicSimplifierVisitor::HandleDynamicUpdateSlice(
|
||||
HloInstruction* dynamic_update_slice, HloInstruction* operand,
|
||||
HloInstruction* update, HloInstruction* start_indices) {
|
||||
|
@ -196,9 +196,9 @@ class DfsHloVisitor {
|
||||
tensorflow::StringPiece custom_call_target) = 0;
|
||||
virtual Status HandleSlice(HloInstruction* slice,
|
||||
HloInstruction* operand) = 0;
|
||||
virtual Status HandleDynamicSlice(
|
||||
HloInstruction* slice,
|
||||
tensorflow::gtl::ArraySlice<HloInstruction*> operands) = 0;
|
||||
virtual Status HandleDynamicSlice(HloInstruction* dynamic_slice,
|
||||
HloInstruction* operand,
|
||||
HloInstruction* start_indices) = 0;
|
||||
virtual Status HandleDynamicUpdateSlice(HloInstruction* dynamic_update_slice,
|
||||
HloInstruction* operand,
|
||||
HloInstruction* update,
|
||||
|
@ -134,10 +134,10 @@ class DfsHloVisitorWithDefault : public DfsHloVisitor {
|
||||
HloInstruction* /*operand*/) override {
|
||||
return DefaultAction(slice);
|
||||
}
|
||||
Status HandleDynamicSlice(
|
||||
HloInstruction* slice,
|
||||
tensorflow::gtl::ArraySlice<HloInstruction*> /*operands*/) override {
|
||||
return DefaultAction(slice);
|
||||
Status HandleDynamicSlice(HloInstruction* dynamic_slice,
|
||||
HloInstruction* /*operand*/,
|
||||
HloInstruction* /*start_indices*/) override {
|
||||
return DefaultAction(dynamic_slice);
|
||||
}
|
||||
Status HandleDynamicUpdateSlice(HloInstruction* dynamic_update_slice,
|
||||
HloInstruction* /*operand*/,
|
||||
|
@ -136,9 +136,9 @@ Status HloCostAnalysis::HandleSlice(HloInstruction* slice,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status HloCostAnalysis::HandleDynamicSlice(
|
||||
HloInstruction* slice,
|
||||
tensorflow::gtl::ArraySlice<HloInstruction*> operands) {
|
||||
Status HloCostAnalysis::HandleDynamicSlice(HloInstruction* dynamic_slice,
|
||||
HloInstruction* operand,
|
||||
HloInstruction* start_indices) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -89,9 +89,9 @@ class HloCostAnalysis : public DfsHloVisitor {
|
||||
tensorflow::gtl::ArraySlice<HloInstruction*> operands,
|
||||
tensorflow::StringPiece custom_call_target) override;
|
||||
Status HandleSlice(HloInstruction* slice, HloInstruction* operand) override;
|
||||
Status HandleDynamicSlice(
|
||||
HloInstruction* slice,
|
||||
tensorflow::gtl::ArraySlice<HloInstruction*> operands) override;
|
||||
Status HandleDynamicSlice(HloInstruction* dynamic_slice,
|
||||
HloInstruction* operand,
|
||||
HloInstruction* start_indices) override;
|
||||
Status HandleDynamicUpdateSlice(HloInstruction* dynamic_update_slice,
|
||||
HloInstruction* operand,
|
||||
HloInstruction* update,
|
||||
|
@ -1782,7 +1782,7 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) {
|
||||
case HloOpcode::kSlice:
|
||||
return visitor->HandleSlice(this, operands_[0]);
|
||||
case HloOpcode::kDynamicSlice:
|
||||
return visitor->HandleDynamicSlice(this, operands_);
|
||||
return visitor->HandleDynamicSlice(this, operands_[0], operands_[1]);
|
||||
case HloOpcode::kDynamicUpdateSlice:
|
||||
return visitor->HandleDynamicUpdateSlice(this, operands_[0], operands_[1],
|
||||
operands_[2]);
|
||||
|
@ -20,6 +20,7 @@ py_library(
|
||||
"//tensorflow/contrib/copy_graph:copy_graph_py",
|
||||
"//tensorflow/contrib/crf:crf_py",
|
||||
"//tensorflow/contrib/cudnn_rnn:cudnn_rnn_py",
|
||||
"//tensorflow/contrib/data",
|
||||
"//tensorflow/contrib/deprecated:deprecated_py",
|
||||
"//tensorflow/contrib/distributions:distributions_py",
|
||||
"//tensorflow/contrib/factorization:factorization_py",
|
||||
|
@ -25,6 +25,7 @@ from tensorflow.contrib import compiler
|
||||
from tensorflow.contrib import copy_graph
|
||||
from tensorflow.contrib import crf
|
||||
from tensorflow.contrib import cudnn_rnn
|
||||
from tensorflow.contrib import data
|
||||
from tensorflow.contrib import deprecated
|
||||
from tensorflow.contrib import distributions
|
||||
from tensorflow.contrib import factorization
|
||||
|
@ -14,6 +14,23 @@ For prebuilt libraries, see the
|
||||
[nightly Android build artifacts](https://ci.tensorflow.org/view/Nightly/job/nightly-android/)
|
||||
page for a recent build.
|
||||
|
||||
The TensorFlow Inference Interface is also available as a
|
||||
[JCenter package](https://bintray.com/google/tensorflow/tensorflow-android) and
|
||||
can be included quite simply in your android project with a couple of lines in
|
||||
the project's `build.gradle` file:
|
||||
|
||||
```
|
||||
allprojects {
|
||||
repositories {
|
||||
jcenter()
|
||||
}
|
||||
}
|
||||
|
||||
dependencies {
|
||||
compile 'org.tensorflow:tensorflow-android:1.2.0-preview'
|
||||
}
|
||||
```
|
||||
|
||||
To build the libraries yourself (if, for example, you want to support custom
|
||||
TensorFlow operators), pick your preferred approach below:
|
||||
|
||||
|
@ -18,6 +18,7 @@ set(tf_op_lib_names
|
||||
"control_flow_ops"
|
||||
"ctc_ops"
|
||||
"data_flow_ops"
|
||||
"dataset_ops"
|
||||
"functional_ops"
|
||||
"image_ops"
|
||||
"io_ops"
|
||||
|
@ -265,6 +265,11 @@ add_python_module("tensorflow/contrib/cudnn_rnn/ops")
|
||||
add_python_module("tensorflow/contrib/cudnn_rnn/python")
|
||||
add_python_module("tensorflow/contrib/cudnn_rnn/python/kernel_tests")
|
||||
add_python_module("tensorflow/contrib/cudnn_rnn/python/ops")
|
||||
add_python_module("tensorflow/contrib/data")
|
||||
add_python_module("tensorflow/contrib/data/python")
|
||||
add_python_module("tensorflow/contrib/data/python/framework")
|
||||
add_python_module("tensorflow/contrib/data/python/kernel_tests")
|
||||
add_python_module("tensorflow/contrib/data/python/ops")
|
||||
add_python_module("tensorflow/contrib/deprecated")
|
||||
add_python_module("tensorflow/contrib/distributions")
|
||||
add_python_module("tensorflow/contrib/distributions/python")
|
||||
@ -592,6 +597,7 @@ GENERATE_PYTHON_OP_LIB("control_flow_ops"
|
||||
ADDITIONAL_LIBRARIES $<TARGET_OBJECTS:tf_no_op>)
|
||||
GENERATE_PYTHON_OP_LIB("ctc_ops")
|
||||
GENERATE_PYTHON_OP_LIB("data_flow_ops")
|
||||
GENERATE_PYTHON_OP_LIB("dataset_ops")
|
||||
GENERATE_PYTHON_OP_LIB("image_ops")
|
||||
GENERATE_PYTHON_OP_LIB("io_ops")
|
||||
GENERATE_PYTHON_OP_LIB("linalg_ops")
|
||||
|
@ -144,6 +144,7 @@ if (tensorflow_BUILD_PYTHON_TESTS)
|
||||
"${tensorflow_source_dir}/tensorflow/python/saved_model/*_test.py"
|
||||
"${tensorflow_source_dir}/tensorflow/python/training/*_test.py"
|
||||
"${tensorflow_source_dir}/tensorflow/tensorboard/*_test.py"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/data/*_test.py"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/factorization/*_test.py"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/keras/python/keras/integration_test.py"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/stateless/python/kernel_tests/*_test.py"
|
||||
@ -204,6 +205,7 @@ if (tensorflow_BUILD_PYTHON_TESTS)
|
||||
"${tensorflow_source_dir}/tensorflow/tensorboard/lib/python/http_util_test.py"
|
||||
# Broken tensorboard test due to cmake issues.
|
||||
"${tensorflow_source_dir}/tensorflow/tensorboard/plugins/debugger/plugin_test.py"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py"
|
||||
# tensor_forest tests (also note that we exclude the hybrid tests for now)
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/kernel_tests/count_extremely_random_stats_op_test.py" # Results in wrong order.
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/kernel_tests/sample_inputs_op_test.py" # Results in wrong order.
|
||||
|
27
tensorflow/contrib/data/BUILD
Normal file
27
tensorflow/contrib/data/BUILD
Normal file
@ -0,0 +1,27 @@
|
||||
package(default_visibility = ["//tensorflow:internal"])
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
py_library(
|
||||
name = "data",
|
||||
srcs = ["__init__.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/contrib/data/python/ops:dataset_ops",
|
||||
"//tensorflow/python:util",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(
|
||||
["**/*"],
|
||||
exclude = [
|
||||
"**/METADATA",
|
||||
"**/OWNERS",
|
||||
],
|
||||
),
|
||||
visibility = ["//tensorflow:__subpackages__"],
|
||||
)
|
630
tensorflow/contrib/data/README.md
Normal file
630
tensorflow/contrib/data/README.md
Normal file
@ -0,0 +1,630 @@
|
||||
# Using the `Dataset` API for TensorFlow Input Pipelines
|
||||
|
||||
The `Dataset` API is designed to let you build complex input pipelines from
|
||||
simple, reusable pieces. For example, the pipeline for an image model might
|
||||
aggregate data from files in a distributed file system, apply random
|
||||
perturbations to each image, and merge randomly selected images into a batch
|
||||
for training. The pipeline for a text model might involve extracting symbols
|
||||
from raw text data, converting them to embedding identifiers with a lookup
|
||||
table, and batching together sequences of different lengths. The `Dataset` API
|
||||
makes it easy to deal with large amounts of data, different data formats, and
|
||||
complicated transformations.
|
||||
|
||||
The `Dataset` API introduces two new abstractions to TensorFlow:
|
||||
|
||||
* A `tf.contrib.data.Dataset` represents a sequence of elements, in which
|
||||
each element contains one or more `Tensor` objects. For example, in an image
|
||||
pipeline, an element might be a single training example, with a pair of
|
||||
tensors representing the image data and a label. A `Dataset` can either be a
|
||||
*source* (e.g. `Dataset.from_tensor_slices()` constructs a dataset from one
|
||||
or more `tf.Tensor` objects), or a *transformation* (e.g. `Dataset.batch()`
|
||||
constructs a dataset by stacking consecutive elements of another dataset into
|
||||
a single element).
|
||||
|
||||
* A `tf.contrib.data.Iterator` provides the main way to extract elements from a
|
||||
dataset. The `Iterator.get_next()` operation yields the next element of a
|
||||
`Dataset`, and typically acts as the interface between input pipeline code and
|
||||
your model. The simplest iterator is a "one-shot iterator", which is
|
||||
associated with a particular `Dataset` and iterates through it once. For more
|
||||
sophisticated uses, the `Iterator.initializer` operation enables you to
|
||||
reinitialize and parameterize an iterator with different datasets, so that
|
||||
you can, for example, iterate over training and validation data multiple times
|
||||
in the same program.
|
||||
|
||||
## Tutorial
|
||||
|
||||
This programmers' guide includes step-by-step instructions for a variety of
|
||||
input data use cases. Also see the `Dataset` and `Iterator` class references
|
||||
for more detailed information about the API.
|
||||
|
||||
### Basic mechanics
|
||||
|
||||
This section of the guide describes the fundamentals of creating different kinds
|
||||
of `Dataset` and `Iterator` objects, and how to extract data from them.
|
||||
|
||||
#### Defining a source dataset
|
||||
|
||||
You can build a `Dataset` using one of the following *source* dataset
|
||||
constructors:
|
||||
|
||||
* From in-memory data:
|
||||
* `tf.contrib.data.Dataset.from_tensors()`
|
||||
* `tf.contrib.data.Dataset.from_tensor_slices()`
|
||||
|
||||
* From on-disk data:
|
||||
* `tf.contrib.data.FixedLengthRecordDataset()`
|
||||
* `tf.contrib.data.TextLineDataset()`
|
||||
* `tf.contrib.data.TFRecordDataset()`
|
||||
|
||||
* From parameters:
|
||||
* `tf.contrib.data.Dataset.range()`
|
||||
|
||||
#### Transforming a dataset
|
||||
|
||||
The `tf.contrib.data.Dataset` class has many methods that can be chained
|
||||
together to *transform* one dataset into another:
|
||||
|
||||
* Per-element transformations:
|
||||
* `Dataset.filter()`
|
||||
* `Dataset.flat_map()`
|
||||
* `Dataset.map()`
|
||||
* `Dataset.zip()`
|
||||
|
||||
* Multi-element transformations:
|
||||
* `Dataset.batch()`
|
||||
* `Dataset.dense_to_sparse_batch()`
|
||||
* `Dataset.group_by_window()`
|
||||
* `Dataset.padded_batch()`
|
||||
* `Dataset.repeat()`
|
||||
* `Dataset.shuffle()`
|
||||
* `Dataset.skip()`
|
||||
* `Dataset.take()`
|
||||
|
||||
The following sections contain examples of how to use these transformations to
|
||||
solve common problems.
|
||||
|
||||
#### Dataset structure
|
||||
|
||||
A dataset comprises elements that each have the same structure. An element
|
||||
contains one or more `tf.Tensor` objects, called *components*. Each component
|
||||
has a `tf.DType` representing the type of elements in the tensor, and a
|
||||
`tf.TensorShape` representing the (possibly partially specified) static shape of
|
||||
each element. The `Dataset.output_types` and `Dataset.output_shapes` properties
|
||||
allow you to inspect the inferred types and shapes of each component of a
|
||||
dataset element. The *nested structure* of these properties map to the structure
|
||||
of an element, which may be a single tensor, a tuple of tensors, or a nested
|
||||
tuple of tensors. For example:
|
||||
|
||||
```python
|
||||
dataset1 = tf.contrib.data.Dataset.from_tensor_slices(tf.random_uniform([4, 10]))
|
||||
print(dataset1.output_types) # ==> "tf.float32"
|
||||
print(dataset1.output_shapes) # ==> "(10,)"
|
||||
|
||||
dataset2 = tf.contrib.data.Dataset.from_tensor_slices(
|
||||
(tf.random_uniform([4]), tf.random_uniform([4, 100], dtype=tf.int32)))
|
||||
print(dataset2.output_types) # ==> "(tf.float32, tf.int32)"
|
||||
print(dataset2.output_shapes) # ==> "((), (100,))"
|
||||
|
||||
dataset3 = tf.contrib.data.Dataset.zip((dataset1, dataset2))
|
||||
print(dataset3.output_types) # ==> (tf.float32, (tf.float32, tf.int32))
|
||||
print(dataset3.output_shapes) # ==> "((), (100,))"
|
||||
```
|
||||
|
||||
The `Dataset` transformations support datasets of any structure. When using the
|
||||
`Dataset.map()`, `Dataset.flat_map()` and `Dataset.filter()` transformations,
|
||||
which apply a function to each element, the element structure determines the
|
||||
arguments of the function:
|
||||
|
||||
```python
|
||||
dataset1 = dataset1.map(lambda x: ...)
|
||||
|
||||
dataset2 = dataset2.flat_map(lambda x, y: ...)
|
||||
|
||||
# *N.B.* Lambda argument destructuring is not available in Python 3.
|
||||
dataset3 = dataset3.filter(lambda x, (y, z): ...)
|
||||
```
|
||||
|
||||
#### Creating an iterator
|
||||
|
||||
One you have built a `Dataset` to represent your input data, the next step is to
|
||||
create an `Iterator` to access elements from that dataset. The `Dataset` API
|
||||
currently supports three kinds of iterator, in increasing level of
|
||||
sophistication:
|
||||
|
||||
A *one-shot* iterator is the simplest form of iterator, which only supports
|
||||
iterating once through a dataset, with no need for explicit initialization.
|
||||
One-shot iterators handle almost all of the cases that the existing queue-based
|
||||
input pipelines support, but they do not support parameterization. Using the
|
||||
example of `Dataset.range()`:
|
||||
|
||||
```python
|
||||
dataset = tf.contrib.data.Dataset.range(100)
|
||||
iterator = dataset.make_one_shot_iterator()
|
||||
next_element = iterator.get_next()
|
||||
|
||||
for i in range(100):
|
||||
value = sess.run(next_element)
|
||||
assert i == value
|
||||
```
|
||||
|
||||
An *initializable* iterator requires you to run an explicit
|
||||
`iterator.initializer` operation before using it. In exchange for this
|
||||
inconvenience, it enables you to *parameterize* the definition of the dataset,
|
||||
using one or more `tf.placeholder()` tensors that can be fed when you
|
||||
initialize the iterator. Continuing the `Dataset.range()` example:
|
||||
|
||||
```python
|
||||
max_value = tf.placeholder(tf.int64, shape=[])
|
||||
dataset = tf.contrib.data.Dataset.range(max_value)
|
||||
iterator = dataset.make_initializable_iterator()
|
||||
next_element = iterator.get_next()
|
||||
|
||||
# Initialize an iterator over a dataset with 10 elements.
|
||||
sess.run(iterator.initializer, feed_dict={max_value: 10})
|
||||
for i in range(10):
|
||||
value = sess.run(next_element)
|
||||
assert i == value
|
||||
|
||||
# Initialize the same iterator over a dataset with 100 elements.
|
||||
sess.run(iterator.initializer, feed_dict={max_value: 100})
|
||||
for i in range(100):
|
||||
value = sess.run(next_element)
|
||||
assert i == value
|
||||
```
|
||||
|
||||
A *reinitializable* iterator can be initialized from multiple different
|
||||
`Dataset` objects. For example, you might have a training input pipeline that
|
||||
uses random perturbations to the input images to improve generalization, and
|
||||
a validation input pipeline that evaluates predictions on unmodified data. These
|
||||
pipelines will typically use different `Dataset` objects that have the same
|
||||
structure (i.e. the same types and compatible shapes for each component).
|
||||
|
||||
```python
|
||||
training_dataset = tf.contrib.data.Dataset.range(100).map(
|
||||
lambda x: x + tf.random_uniform([], -10, 10, tf.int64))
|
||||
validation_dataset = tf.contrib.data.Dataset.range(50)
|
||||
|
||||
# A reinitializable iterator is defined by its structure. We could use the
|
||||
# `output_types` and `output_shapes` properties of either `training_dataset`
|
||||
# or `validation_dataset` here, because they are compatible.
|
||||
iterator = Iterator.from_structure(training_dataset.output_types,
|
||||
training_dataset.output_shapes)
|
||||
next_element = iterator.get_next()
|
||||
|
||||
training_init_op = iterator.make_initializer(training_dataset)
|
||||
validation_init_op = iterator.make_initializer(validation_dataset)
|
||||
|
||||
# Run 20 epochs in which the training dataset is traversed, followed by the
|
||||
# validation dataset.
|
||||
for _ in range(20):
|
||||
# Initialize an iterator over the training dataset.
|
||||
sess.run(training_init_op)
|
||||
for _ in range(100):
|
||||
sess.run(next_element)
|
||||
|
||||
# Initialize an iterator over the validation dataset.
|
||||
sess.run(validation_init_op)
|
||||
for _ in range(50):
|
||||
sess.run(next_element)
|
||||
```
|
||||
|
||||
#### Consuming values from an iterator
|
||||
|
||||
The `Iterator.get_next()` method returns one or more `tf.Tensor` objects that
|
||||
correspond to the symbolic next element of an iterator. Each time these tensors
|
||||
are evaluated, they take the value of the next element in the underlying
|
||||
dataset. (Note that, like other stateful objects in TensorFlow, calling
|
||||
`Iterator.get_next()` does not immediately advance the iterator. Instead you
|
||||
must use the returned `tf.Tensor` objects in a TensorFlow expression, and pass
|
||||
the result of that expression to `tf.Session.run()` to get the next elements and
|
||||
advance the iterator.)
|
||||
|
||||
If the iterator reaches the end of the dataset, executing
|
||||
the `Iterator.get_next()` operation will raise a `tf.errors.OutOfRangeError`.
|
||||
After this point the iterator will be in an unusable state, and you must
|
||||
initialize it again if you want to use it further.
|
||||
|
||||
```python
|
||||
dataset = tf.contrib.data.Dataset.range(5)
|
||||
iterator = dataset.make_initializable_iterator()
|
||||
next_element = iterator.get_next()
|
||||
|
||||
# Typically `result` will be the output of a model, or an optimizer's
|
||||
# training operation.
|
||||
result = tf.add(next_element, next_element)
|
||||
|
||||
sess.run(iterator.initializer)
|
||||
print(sess.run(result)) # ==> "0"
|
||||
print(sess.run(result)) # ==> "2"
|
||||
print(sess.run(result)) # ==> "4"
|
||||
print(sess.run(result)) # ==> "6"
|
||||
print(sess.run(result)) # ==> "8"
|
||||
try:
|
||||
sess.run(result)
|
||||
except tf.errors.OutOfRangeError:
|
||||
print("End of dataset") # ==> "End of dataset"
|
||||
```
|
||||
|
||||
A common pattern is to wrap the "training loop" in a `try`-`except` block:
|
||||
|
||||
```python
|
||||
sess.run(iterator.initializer)
|
||||
while True:
|
||||
try:
|
||||
sess.run(result)
|
||||
except tf.errors.OutOfRangeError:
|
||||
break
|
||||
```
|
||||
|
||||
If each element of the dataset has a nested structure, the return value of
|
||||
`Iterator.get_next()` will be one or more `tf.Tensor` objects in the same
|
||||
nested structure:
|
||||
|
||||
```python
|
||||
dataset1 = tf.contrib.data.Dataset.from_tensor_slices(tf.random_uniform([4, 10]))
|
||||
dataset2 = tf.contrib.data.Dataset.from_tensor_slices((tf.random_uniform([4]), tf.random_uniform([4, 100])))
|
||||
dataset3 = tf.contrib.data.Dataset.zip((dataset1, dataset2))
|
||||
|
||||
iterator = dataset3.make_initializable_iterator()
|
||||
|
||||
sess.run(iterator.initializer)
|
||||
next1, (next2, next3) = iterator.get_next()
|
||||
```
|
||||
|
||||
Note that evaluating *any* of `next1`, `next2`, or `next3` will advance the
|
||||
iterator for all components. A typical consumer of an iterator will include all
|
||||
components in a single expression.
|
||||
|
||||
### Reading input data
|
||||
|
||||
#### Consuming NumPy arrays
|
||||
|
||||
If all of your input data fit in memory, the simplest way to create a `Dataset`
|
||||
from them is to convert them to `tf.Tensor` objects and use
|
||||
`Dataset.from_tensor_slices()`.
|
||||
|
||||
```python
|
||||
# Load the training data into two NumPy arrays, for example using `np.load()`.
|
||||
with np.load("/var/data/training_data.npy") as data:
|
||||
features = data["features"]
|
||||
labels = data["labels"]
|
||||
|
||||
# Assume that each row of `features` corresponds to the same row as `labels`.
|
||||
assert features.shape[0] == labels.shape[0]
|
||||
|
||||
dataset = tf.contrib.data.Dataset.from_tensor_slices((features, labels))
|
||||
```
|
||||
|
||||
Note that the above code snippet will embed the `features` and `labels` arrays
|
||||
in your TensorFlow graph as constants. This works well for a small dataset, but
|
||||
wastes memory, and can run into the 2GB limit for the `tf.GraphDef` protocol
|
||||
buffer.
|
||||
|
||||
As an alternative, you can define the `Dataset` in terms of `tf.placeholder()`
|
||||
tensors, and *feed* the NumPy arrays when you initialize an `Iterator` over the
|
||||
dataset.
|
||||
|
||||
```python
|
||||
# Load the training data into two NumPy arrays, for example using `np.load()`.
|
||||
with np.load("/var/data/training_data.npy") as data:
|
||||
features = data["features"]
|
||||
labels = data["labels"]
|
||||
|
||||
# Assume that each row of `features` corresponds to the same row as `labels`.
|
||||
assert features.shape[0] == labels.shape[0]
|
||||
|
||||
features_placeholder = tf.placeholder(features.dtype, features.shape)
|
||||
labels_placeholder = tf.placeholder(labels.dtype, labels.shape)
|
||||
|
||||
dataset = tf.contrib.data.Dataset.from_tensor_slices((features_placeholder, labels_placeholder))
|
||||
# [Other transformations on `dataset`...]
|
||||
dataset = ...
|
||||
iterator = dataset.make_initializable_iterator()
|
||||
|
||||
sess.run(iterator.initializer, feed_dict={features_placeholder: features,
|
||||
labels_placeholder: labels})
|
||||
```
|
||||
|
||||
#### Consuming TFRecord data
|
||||
|
||||
The `Dataset` API supports a variety of file formats so that you can process
|
||||
large datasets that do not fit in memory. The TFRecord file format is a
|
||||
simple record-oriented binary format that many TensorFlow applications use for
|
||||
training data. The `tf.contrib.data.TFRecordDataset` class enables you to
|
||||
stream over the contents of one or more TFRecord files as part of an input
|
||||
pipeline.
|
||||
|
||||
```python
|
||||
# Creates a dataset that reads all of the examples from two files.
|
||||
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
|
||||
dataset = tf.contrib.data.TFRecordDataset(filenames)
|
||||
```
|
||||
|
||||
The `filenames` argument to the `TFRecordDataset` initializer can be a
|
||||
`tf.Tensor` of strings. Therefore if you have two sets of files for training
|
||||
and validation purposes, you can use a `tf.placeholder(tf.string)` to represent
|
||||
the filenames, and initialize an iterator from the appropriate filenames:
|
||||
|
||||
```python
|
||||
filenames = tf.placeholder(tf.string, shape=[None])
|
||||
dataset = tf.contrib.data.TFRecordDataset(filenames)
|
||||
# [Other transformations on `dataset`...]
|
||||
dataset = ...
|
||||
iterator = dataset.make_initializable_iterator()
|
||||
|
||||
# You can feed the initializer with the appropriate filenames for the current
|
||||
# phase of execution, e.g. training vs. validation.
|
||||
|
||||
# Initialize `iterator` with training data.
|
||||
training_filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
|
||||
sess.run(iterator.initializer, feed_dict={filenames: training_filenames})
|
||||
|
||||
# Initialize `iterator` with validation data.
|
||||
validation_filenames = ["/var/data/validation1.tfrecord", ...]
|
||||
sess.run(iterator.initializer, feed_dict={filenames: validation_filenames})
|
||||
```
|
||||
|
||||
#### Consuming text data
|
||||
|
||||
Many datasets are distributed as one or more text files. The
|
||||
`tf.contrib.data.TextLineDataset` provides an easy way to extract lines from
|
||||
one or more text files. Given one or more filenames, a `TextLineDataset` will
|
||||
produce one string-valued element per line of those files. Like a
|
||||
`TFRecordDataset`, `TextLineDataset` accepts `filenames` as a `tf.Tensor`, so
|
||||
you can parameterize it by passing a `tf.placeholder(tf.string)`.
|
||||
|
||||
```python
|
||||
filenames = ["/var/data/file1.txt", "/var/data/file2.txt"]
|
||||
dataset = tf.contrib.data.TextLineDataset(filenames)
|
||||
```
|
||||
|
||||
By default, a `TextLineDataset` yields *every* line of each file, which may
|
||||
not be desirable, for example if the file starts with a header line, or contains
|
||||
comments. These lines can be removed using the `Dataset.skip()` and
|
||||
`Dataset.filter()` transformations. To apply these transformations to each
|
||||
file separately, we use `Dataset.flat_map()` to create a nested `Dataset` for
|
||||
each file.
|
||||
|
||||
```python
|
||||
filenames = ["/var/data/file1.txt", "/var/data/file2.txt"]
|
||||
|
||||
dataset = tf.contrib.data.Dataset.from_tensor_slices(filenames)
|
||||
|
||||
# Use `Dataset.flat_map()` to transform each file separately.
|
||||
# * Skip the first line (header row).
|
||||
# * Filter out lines beginning with "#" (comments).
|
||||
dataset = dataset.flat_map(
|
||||
lambda filename: (
|
||||
tf.contrib.data.Dataset.TextLineDataset(filename)
|
||||
.skip(1)
|
||||
.filter(lambda line: tf.not_equal(tf.substr(line, 0, 1), "#"))))
|
||||
```
|
||||
|
||||
<!--
|
||||
TODO(mrry): Add these sections.
|
||||
|
||||
#### Consuming from a Python generator
|
||||
#### Consuming from an index file and images
|
||||
-->
|
||||
|
||||
### Preprocessing data with `Dataset.map()`
|
||||
|
||||
The `Dataset.map(f)` transformation produces a new dataset by applying a given
|
||||
function `f` to each element of the input dataset. It is based on
|
||||
the
|
||||
[`map()` function](https://en.wikipedia.org/wiki/Map_(higher-order_function))
|
||||
that is commonly applied to lists (and other structures) in functional
|
||||
programming languages. The function `f` takes the `tf.Tensor` objects that
|
||||
represent a single element in the input, and returns the `tf.Tensor` objects
|
||||
that will represent a single element in the new dataset. Its implementation uses
|
||||
standard TensorFlow operations to transform one element into another.
|
||||
|
||||
This section covers common examples of how to use `Dataset.map()`.
|
||||
|
||||
#### Parsing `tf.Example` protocol buffer messages
|
||||
|
||||
Many input pipelines extract `tf.train.Example` protocol buffer messages from a
|
||||
TFRecord-format file (written, for example, using
|
||||
`tf.python_io.TFRecordWriter`). Each `tf.train.Example` record contains one or
|
||||
more "features", and the input pipeline typically converts these features into
|
||||
tensors.
|
||||
|
||||
```python
|
||||
# Transforms a scalar string `example_proto` into a pair of a scalar string and
|
||||
# a scalar integer, representing an image and its label, respectively.
|
||||
def _parse_function(example_proto):
|
||||
features = {"image": tf.FixedLenFeature((), tf.string, default_value=""),
|
||||
"label": tf.FixedLenFeature((), tf.int32, default_value=0)}
|
||||
parsed_features = tf.parse_single_example(example_proto, features)
|
||||
return parsed_features["image"], parsed_features["label"]
|
||||
|
||||
# Creates a dataset that reads all of the examples from two files, and extracts
|
||||
# the image and label features.
|
||||
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
|
||||
dataset = tf.contrib.data.TFRecordDataset(filenames)
|
||||
dataset = dataset.map(_parse_function)
|
||||
```
|
||||
|
||||
#### Decoding image data and resizing it
|
||||
|
||||
When training a neural network on real-world image data, it is often necessary
|
||||
to convert images of different sizes to a common size, so that they may be
|
||||
batched into a fixed size.
|
||||
|
||||
```python
|
||||
# Reads an image from a file, decodes it into a dense tensor, and resizes it
|
||||
# to a fixed shape.
|
||||
def _parse_function(filename, label):
|
||||
image_string = tf.read_file(filename)
|
||||
image_decoded = tf.image.decode_image(filename)
|
||||
image_resized = tf.image.resize_images(image_decoded, [28, 28])
|
||||
return image_resized, label
|
||||
|
||||
filenames = ["/var/data/image1.jpg", "/var/data/image2.jpg", ...]
|
||||
labels = [0, 37, 29, 1, ...]
|
||||
|
||||
dataset = tf.contrib.data.Dataset.from_tensor_slices((filenames, labels))
|
||||
dataset = dataset.map(_parse_function)
|
||||
```
|
||||
|
||||
#### Applying arbitrary Python logic with `tf.py_func()`
|
||||
|
||||
For performance reasons, we encourage you to use TensorFlow operations for
|
||||
preprocessing your data whenever possible. However, it is sometimes useful to
|
||||
be able to call upon external Python libraries when parsing your input data,
|
||||
and you can do this by invoking the `tf.py_func()` operation in a
|
||||
`Dataset.map()` transformation.
|
||||
|
||||
```python
|
||||
import cv2
|
||||
|
||||
# Use a custom OpenCV function to read the image, instead of the standard
|
||||
# TensorFlow `tf.read_file()` operation.
|
||||
def _read_py_function(filename, label):
|
||||
image_decoded = cv2.imread(image_string, cv2.IMREAD_GRAYSCALE)
|
||||
return image_decoded, label
|
||||
|
||||
# Use standard TensorFlow operations to resize the image to a fixed shape.
|
||||
def _resize_function(image_decoded, label):
|
||||
image_decoded.set_shape([None, None, None])
|
||||
image_resized = tf.image.resize_images(image_decoded, [28, 28])
|
||||
return image_resized, label
|
||||
|
||||
filenames = ["/var/data/image1.jpg", "/var/data/image2.jpg", ...]
|
||||
labels = [0, 37, 29, 1, ...]
|
||||
|
||||
dataset = tf.contrib.data.Dataset.from_tensor_slices((filenames, labels))
|
||||
dataset = dataset.map(
|
||||
lambda filename, label: tf.py_func(
|
||||
_read_py_function, [filename, label], [tf.uint8, label.dtype]))
|
||||
dataset = dataset.map(_resize_function)
|
||||
```
|
||||
|
||||
<!--
|
||||
TODO(mrry): Add this section.
|
||||
|
||||
#### Handling text data with unusual sizes
|
||||
-->
|
||||
|
||||
### Batching dataset elements
|
||||
|
||||
#### Simple batching
|
||||
|
||||
The simplest form of batching stacks `n` consecutive elements of a dataset into
|
||||
a single element. The `Dataset.batch()` transformation does exactly this, with
|
||||
the same constraints as the `tf.stack()` operator, applied to each component
|
||||
of the elements: i.e. for each component *i*, all elements must have a tensor
|
||||
of the exact same shape.
|
||||
|
||||
```python
|
||||
inc_dataset = tf.contrib.data.Dataset.range(100)
|
||||
dec_dataset = tf.contrib.data.Dataset.range(0, -100, -1)
|
||||
dataset = tf.contrib.data.Dataset.zip((inc_dataset, dec_dataset))
|
||||
batched_dataset = dataset.batch(4)
|
||||
|
||||
iterator = batched_dataset.make_one_shot_iterator()
|
||||
next_element = iterator.get_next()
|
||||
|
||||
print(sess.run(next_element)) # ==> ([0, 1, 2, 3], [ 0, -1, -2, -3])
|
||||
print(sess.run(next_element)) # ==> ([4, 5, 6, 7], [-4, -5, -6, -7])
|
||||
print(sess.run(next_element)) # ==> ([8, 9, 10, 11], [-8, -9, -10, -11])
|
||||
```
|
||||
|
||||
#### Batching tensors with padding
|
||||
|
||||
The above recipe works for tensors that all have the same size. However, many
|
||||
models (e.g. sequence models) work with input data that can have varying size
|
||||
(e.g. sequences of different lengths). To handle this case, the
|
||||
`Dataset.padded_batch()` transformation enables you to batch tensors of
|
||||
different shape by specifying one or more dimensions in which they may be
|
||||
padded.
|
||||
|
||||
```python
|
||||
dataset = tf.contrib.data.Dataset.range(100)
|
||||
dataset = dataset.map(lambda x: tf.fill([tf.cast(x, tf.int32)], x))
|
||||
dataset = dataset.padded_batch(4, padded_shapes=[None])
|
||||
|
||||
iterator = batched_dataset.make_one_shot_iterator()
|
||||
next_element = iterator.get_next()
|
||||
|
||||
print(sess.run(next_element)) # ==> [[0, 0, 0], [1, 0, 0], [2, 2, 0], [3, 3, 3]]
|
||||
print(sess.run(next_element)) # ==> [[4, 4, 4, 4, 0, 0, 0],
|
||||
# [5, 5, 5, 5, 5, 0, 0],
|
||||
# [6, 6, 6, 6, 6, 6, 0],
|
||||
# [7, 7, 7, 7, 7, 7, 7]]
|
||||
```
|
||||
|
||||
The `Dataset.padded_batch()` transformation allows you to set different padding
|
||||
for each dimension of each component, and it may be variable-length (signified
|
||||
by `None` in the example above) or constant-length. It is also possible to
|
||||
override the padding value, which defaults to 0.
|
||||
|
||||
<!--
|
||||
TODO(mrry): Add this section.
|
||||
|
||||
#### Dense ragged -> tf.SparseTensor
|
||||
-->
|
||||
|
||||
### Training workflows
|
||||
|
||||
#### Processing multiple epochs
|
||||
|
||||
The `Dataset` API offers two main ways to process multiple epochs of the same
|
||||
data.
|
||||
|
||||
The simplest way to iterate over a dataset in multiple epochs is to use the
|
||||
`Dataset.repeat()` transformation. For example, to create a dataset that repeats
|
||||
its input for 10 epochs:
|
||||
|
||||
```python
|
||||
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
|
||||
dataset = tf.contrib.data.TFRecordDataset(filenames)
|
||||
dataset = dataset.map(...)
|
||||
dataset = dataset.repeat(10)
|
||||
dataset = dataset.batch(32)
|
||||
```
|
||||
|
||||
Applying the `Dataset.repeat()` transformation with no arguments will repeat
|
||||
the input indefinitely. The `Dataset.repeat()` transformation concatenates its
|
||||
arguments without signaling the end of one epoch and the beginning of the next
|
||||
epoch.
|
||||
|
||||
If you want to receive a signal at the end of each epoch, you can write a
|
||||
training loop that catches the `tf.errors.OutOfRangeError` at the end of a
|
||||
dataset. At that point you might collect some statistics (e.g. the validation
|
||||
error) for the epoch.
|
||||
|
||||
```python
|
||||
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
|
||||
dataset = tf.contrib.data.TFRecordDataset(filenames)
|
||||
dataset = dataset.map(...)
|
||||
dataset = dataset.batch(32)
|
||||
iterator = dataset.make_initializable_iterator()
|
||||
next_element = iterator.get_next()
|
||||
|
||||
# Compute for 100 epochs.
|
||||
for _ in range(100):
|
||||
sess.run(iterator.initializer)
|
||||
while True:
|
||||
try:
|
||||
sess.run(next_element)
|
||||
except tf.errors.OutOfRangeError:
|
||||
break
|
||||
|
||||
# [Perform end-of-epoch calculations here.]
|
||||
```
|
||||
|
||||
#### Randomly shuffling input data
|
||||
|
||||
The `Dataset.shuffle()` transformation randomly shuffles the input dataset
|
||||
using a similar algorithm to `tf.RandomShuffleQueue`: it maintains a fixed-size
|
||||
buffer and chooses the next element uniformly at random from that buffer.
|
||||
|
||||
```python
|
||||
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
|
||||
dataset = tf.contrib.data.TFRecordDataset(filenames)
|
||||
dataset = dataset.map(...)
|
||||
dataset = dataset.repeat()
|
||||
dataset = dataset.shuffle(buffer_size=10000)
|
||||
dataset = dataset.batch(32)
|
||||
```
|
42
tensorflow/contrib/data/__init__.py
Normal file
42
tensorflow/contrib/data/__init__.py
Normal file
@ -0,0 +1,42 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""`tf.contrib.data.Dataset` API for input pipelines.
|
||||
|
||||
@@Dataset
|
||||
@@Iterator
|
||||
@@TFRecordDataset
|
||||
@@FixedLengthRecordDataset
|
||||
@@TextLineDataset
|
||||
|
||||
@@read_batch_features
|
||||
@@rejection_resample
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# pylint: disable=unused-import
|
||||
from tensorflow.contrib.data.python.ops.dataset_ops import Dataset
|
||||
from tensorflow.contrib.data.python.ops.dataset_ops import FixedLengthRecordDataset
|
||||
from tensorflow.contrib.data.python.ops.dataset_ops import Iterator
|
||||
from tensorflow.contrib.data.python.ops.dataset_ops import read_batch_features
|
||||
from tensorflow.contrib.data.python.ops.dataset_ops import rejection_resample
|
||||
from tensorflow.contrib.data.python.ops.dataset_ops import TextLineDataset
|
||||
from tensorflow.contrib.data.python.ops.dataset_ops import TFRecordDataset
|
||||
# pylint: enable=unused-import
|
||||
|
||||
from tensorflow.python.util.all_util import remove_undocumented
|
||||
remove_undocumented(__name__)
|
41
tensorflow/contrib/data/python/framework/BUILD
Normal file
41
tensorflow/contrib/data/python/framework/BUILD
Normal file
@ -0,0 +1,41 @@
|
||||
package(default_visibility = ["//tensorflow:internal"])
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "py_test")
|
||||
|
||||
py_library(
|
||||
name = "function",
|
||||
srcs = ["function.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/python:framework",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "function_test",
|
||||
size = "medium",
|
||||
srcs = ["function_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":function",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(
|
||||
["**/*"],
|
||||
exclude = [
|
||||
"**/METADATA",
|
||||
"**/OWNERS",
|
||||
],
|
||||
),
|
||||
visibility = ["//tensorflow:__subpackages__"],
|
||||
)
|
267
tensorflow/contrib/data/python/framework/function.py
Normal file
267
tensorflow/contrib/data/python/framework/function.py
Normal file
@ -0,0 +1,267 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""An experimental fork of the Python TensorFlow-function library.
|
||||
|
||||
NOTE: functions are currently experimental and subject to change!
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import function
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import variable_scope as vs
|
||||
from tensorflow.python.util import tf_inspect
|
||||
|
||||
# NOTE(mrry): This is an experimental extension of a core class that wasn't
|
||||
# designed to be extended, so we disable protected access checks for the
|
||||
# whole file.
|
||||
# pylint: disable=protected-access
|
||||
|
||||
|
||||
class _ExperimentalFuncGraph(function._FuncGraph):
|
||||
"""A helper for construction a function (supporting capture-by-value).
|
||||
|
||||
_ExperimentalFuncGraph overrides ops.Graph's create_op() so that we can keep
|
||||
track of every inputs into every op created inside the function. If
|
||||
any input is from other graphs, we keep track of it in self.capture
|
||||
and substitue the input with a place holder.
|
||||
|
||||
Each captured input's corresponding place holder is converted into a
|
||||
function argument and the caller passes in the captured tensor.
|
||||
"""
|
||||
|
||||
def __init__(self, capture_by_value, *args, **kwargs):
|
||||
super(_ExperimentalFuncGraph, self).__init__(*args, **kwargs)
|
||||
self._capture_by_value = capture_by_value
|
||||
self._building_function = True
|
||||
self._outer_graph = ops.get_default_graph()
|
||||
self._vscope = vs.get_variable_scope()
|
||||
self._old_custom_getter = self._vscope.custom_getter
|
||||
self._captured = {}
|
||||
self.extra_inputs = []
|
||||
self.extra_args = []
|
||||
self.extra_vars = []
|
||||
|
||||
def create_op(self, op_type, inputs, data_types, **kwargs):
|
||||
for i, x in enumerate(inputs):
|
||||
if x.graph is not self:
|
||||
# Referring to a tensor from other graph.
|
||||
if x in self._captured:
|
||||
# Captured already.
|
||||
inputs[i] = self._captured[x]
|
||||
elif self._capture_by_value:
|
||||
inputs[i] = self._add_tensor_and_parents(x)
|
||||
else:
|
||||
# Substitute with a placeholder.
|
||||
self.extra_inputs.append(x)
|
||||
ph = array_ops.placeholder(x.dtype, shape=x.get_shape())
|
||||
# pylint: disable=protected-access
|
||||
ph._handle_shape = x._handle_shape
|
||||
ph._handle_dtype = x._handle_dtype
|
||||
# pylint: enable=protected-access
|
||||
inputs[i] = ph
|
||||
self._captured[x] = ph
|
||||
self.extra_args.append(ph)
|
||||
return super(_ExperimentalFuncGraph, self).create_op(op_type, inputs,
|
||||
data_types, **kwargs)
|
||||
|
||||
def _add_tensor_and_parents(self, tensor):
|
||||
op = self._add_op_and_parents(tensor.op)
|
||||
return op.outputs[tensor.value_index]
|
||||
|
||||
def _add_op_and_parents(self, op):
|
||||
op_def = function._get_op_def(op)
|
||||
if op_def.is_stateful:
|
||||
raise ValueError("Cannot capture a stateful node by value.")
|
||||
elif op.type in ("Placeholder", "PlaceholderV2"):
|
||||
raise ValueError("Cannot capture a placeholder by value.")
|
||||
|
||||
captured_inputs = [self._add_tensor_and_parents(x) for x in op.inputs]
|
||||
|
||||
captured_op = self.create_op(op.type, captured_inputs,
|
||||
[o.dtype for o in op.outputs],
|
||||
name=op.name, attrs=op.node_def.attr,
|
||||
op_def=op_def)
|
||||
|
||||
for t, captured_t in zip(op.outputs, captured_op.outputs):
|
||||
self._captured[t] = captured_t
|
||||
|
||||
return captured_op
|
||||
|
||||
|
||||
class _ExperimentalDefinedFunction(function._DefinedFunction):
|
||||
"""Overrides _DefinedFunction with support for capture-by-value."""
|
||||
|
||||
def __init__(self,
|
||||
func,
|
||||
argnames,
|
||||
input_types,
|
||||
func_name=None,
|
||||
grad_func=None,
|
||||
python_grad_func=None,
|
||||
out_names=None,
|
||||
shape_func=None,
|
||||
capture_by_value=False,
|
||||
**kwargs):
|
||||
"""Creates an _ExperimentalDefinedFunction.
|
||||
|
||||
Args:
|
||||
func: A python callable which constructs a tf function body.
|
||||
argnames: A list of strings for function argument names.
|
||||
input_types: The function's argument types. Can be a tuple, list of
|
||||
tf data types.
|
||||
func_name: The function name. Defaults to None, in which derives from
|
||||
'func'.
|
||||
grad_func: This function's gradient function, if not None. Defaults
|
||||
to None.
|
||||
python_grad_func: A python callable implementing the gradient of
|
||||
the function python-side.
|
||||
out_names: An optional list of strings for the function return value
|
||||
names.
|
||||
shape_func: An optional function mapping an op to a list of static
|
||||
output shapes.
|
||||
capture_by_value: Boolean (defaults to False). If True, captured values
|
||||
will be copied into the function body.
|
||||
**kwargs: The keyword arguments. **kwargs is passed to every call
|
||||
site of this function.
|
||||
|
||||
Raises:
|
||||
ValueError: The function definition is invalid.
|
||||
"""
|
||||
super(_ExperimentalDefinedFunction, self).__init__(
|
||||
func, argnames, input_types, func_name, grad_func, python_grad_func,
|
||||
out_names, shape_func, **kwargs)
|
||||
self._capture_by_value = capture_by_value
|
||||
|
||||
def _create_definition_if_needed(self):
|
||||
"""Creates the function definition if it's not created yet."""
|
||||
|
||||
if self._definition is not None:
|
||||
return
|
||||
|
||||
# Create the func_def object.
|
||||
temp_graph = _ExperimentalFuncGraph(capture_by_value=self._capture_by_value)
|
||||
with temp_graph.as_default():
|
||||
# List of placeholders for the function_def.
|
||||
inputs = []
|
||||
for (argname, argtype) in self._args:
|
||||
argholder = array_ops.placeholder(argtype, name=argname)
|
||||
inputs.append(argholder)
|
||||
# Call func and gather the output tensors.
|
||||
with vs.variable_scope("", custom_getter=temp_graph.getvar):
|
||||
outputs = self._func(*inputs)
|
||||
# If func only returned one value, make it a tuple.
|
||||
if not isinstance(outputs, (list, tuple)):
|
||||
outputs = (outputs,)
|
||||
if any([_ is None for _ in outputs]):
|
||||
raise ValueError("Function can not return None.")
|
||||
# Ensures each output is a Tensor.
|
||||
outputs = [ops.convert_to_tensor(_) for _ in outputs]
|
||||
self._extra_inputs = temp_graph.extra_inputs
|
||||
inputs.extend(temp_graph.extra_args)
|
||||
self._sub_functions = temp_graph._functions
|
||||
|
||||
# Build the FunctionDef
|
||||
self._definition = function._graph_to_function_def(
|
||||
temp_graph, temp_graph.get_operations(), inputs, outputs,
|
||||
out_names=self._out_names)
|
||||
|
||||
# Extra kwargs are treated as attrs on the function def.
|
||||
sig_pre_func_name = self._func_name or function._get_func_name(self._func)
|
||||
kwargs_attr = function._parse_kwargs_as_attrs(
|
||||
sig_pre_func_name, **self._extra_kwargs)
|
||||
for k in kwargs_attr:
|
||||
self._definition.attr[k].CopyFrom(kwargs_attr[k])
|
||||
|
||||
# Hash the definition and its dependencies.
|
||||
self._hash_str = self._create_hash_str(
|
||||
self._definition.signature.input_arg,
|
||||
self._definition.signature.output_arg,
|
||||
self._definition.node_def)
|
||||
|
||||
# Finally, we decide the function name to use. If not specified,
|
||||
# make up something which is almost certainly unique (but deterministic).
|
||||
if not self._func_name:
|
||||
self._func_name = "_".join([function._get_func_name(self._func),
|
||||
self._hash_str])
|
||||
self._definition.signature.name = self._func_name
|
||||
if self._func.__doc__:
|
||||
self._definition.signature.description = self._func.__doc__
|
||||
|
||||
|
||||
class Defun(function.Defun):
|
||||
"""Experimental version of Defun supporting capture-by-value."""
|
||||
|
||||
def __init__(self, *input_types, **kwargs):
|
||||
"""Create an experimental `Defun` decorator.
|
||||
|
||||
Args:
|
||||
*input_types: A list of `tf.DType`
|
||||
**kwargs: Optional keyword arguments (see `function.Defun`) plus:
|
||||
capture_by_value - Boolean (defaults to False). If True, captured values
|
||||
will be copied into the function body.
|
||||
"""
|
||||
super(Defun, self).__init__(*input_types, **kwargs)
|
||||
|
||||
def __call__(self, func):
|
||||
# Various sanity checks on the callable func.
|
||||
if not callable(func):
|
||||
raise ValueError("func %s must be callable" % func)
|
||||
|
||||
# Func should not use kwargs and defaults.
|
||||
argspec = tf_inspect.getargspec(func)
|
||||
if argspec.keywords or argspec.defaults:
|
||||
raise ValueError("Functions with argument defaults or keyword "
|
||||
"arguments are not supported.")
|
||||
|
||||
# Computes how many arguments 'func' has.
|
||||
min_args = len(argspec.args)
|
||||
max_args = min_args
|
||||
if argspec.varargs:
|
||||
max_args = 1000000
|
||||
argnames = argspec.args
|
||||
if tf_inspect.ismethod(func):
|
||||
# 1st argument is the "class" type.
|
||||
min_args -= 1
|
||||
argnames = argnames[1:]
|
||||
|
||||
if self._input_types:
|
||||
# If Defun is given a list of types for the inputs, the number
|
||||
# of input types should be compatible with 'func'.
|
||||
num = len(self._input_types)
|
||||
if num < min_args or num > max_args:
|
||||
raise ValueError(
|
||||
"The function has fewer arguments than the number of specified "
|
||||
"input types.")
|
||||
return _ExperimentalDefinedFunction(
|
||||
func, argnames, self._input_types, self._func_name, self._grad_func,
|
||||
self._python_grad_func, out_names=self._out_names,
|
||||
**self._extra_kwargs)
|
||||
|
||||
# 'func' expects no arguments and input types is an empty list.
|
||||
if min_args == 0 and max_args == 0:
|
||||
return _ExperimentalDefinedFunction(
|
||||
func, [], [], self._func_name, self._grad_func,
|
||||
self._python_grad_func, out_names=self._out_names,
|
||||
**self._extra_kwargs)
|
||||
|
||||
# Input types are unknown. It's an overloaded function and hence
|
||||
# its definition needs to be deferred until it's called.
|
||||
return function._OverloadedFunction(
|
||||
func, argnames, self._func_name, self._grad_func,
|
||||
self._python_grad_func, out_names=self._out_names, **self._extra_kwargs)
|
59
tensorflow/contrib/data/python/framework/function_test.py
Normal file
59
tensorflow/contrib/data/python/framework/function_test.py
Normal file
@ -0,0 +1,59 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for experimental capture-by-value feature in TF functions."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.data.python.framework import function
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class FunctionTest(test.TestCase):
|
||||
|
||||
def testCaptureByValue(self):
|
||||
g = ops.Graph()
|
||||
with g.as_default():
|
||||
w = constant_op.constant([[1.0]])
|
||||
b = constant_op.constant([2.0])
|
||||
|
||||
# Foo() captures w and b.
|
||||
@function.Defun(dtypes.float32, capture_by_value=True)
|
||||
def Foo(x):
|
||||
|
||||
# Plus() captures b.
|
||||
@function.Defun(dtypes.float32, capture_by_value=True)
|
||||
def Plus(y):
|
||||
return y + b
|
||||
|
||||
self.assertEqual(0, len(Plus.captured_inputs))
|
||||
|
||||
return Plus(math_ops.matmul(w, x))
|
||||
|
||||
y = Foo(constant_op.constant([[10.]]))
|
||||
|
||||
self.assertEqual(0, len(Foo.captured_inputs))
|
||||
|
||||
with self.test_session(graph=g):
|
||||
self.assertAllEqual(y.eval(), [[12.0]])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
219
tensorflow/contrib/data/python/kernel_tests/BUILD
Normal file
219
tensorflow/contrib/data/python/kernel_tests/BUILD
Normal file
@ -0,0 +1,219 @@
|
||||
package(default_visibility = ["//tensorflow:internal"])
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "py_test")
|
||||
|
||||
py_test(
|
||||
name = "iterator_ops_test",
|
||||
size = "small",
|
||||
srcs = ["iterator_ops_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/contrib/data",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:training",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "batch_dataset_op_test",
|
||||
size = "small",
|
||||
srcs = ["batch_dataset_op_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/contrib/data",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:platform_test",
|
||||
"//tensorflow/python:string_ops",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "bucketing_test",
|
||||
size = "small",
|
||||
srcs = ["bucketing_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/contrib/data",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:string_ops",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "dataset_constructor_op_test",
|
||||
size = "small",
|
||||
srcs = ["dataset_constructor_op_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/contrib/data",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "filter_dataset_op_test",
|
||||
size = "small",
|
||||
srcs = ["filter_dataset_op_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/contrib/data",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "flat_map_dataset_op_test",
|
||||
size = "small",
|
||||
srcs = ["flat_map_dataset_op_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/contrib/data",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:platform_test",
|
||||
"//tensorflow/python:training",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "map_dataset_op_test",
|
||||
size = "small",
|
||||
srcs = ["map_dataset_op_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/contrib/data",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:data_flow_ops",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:lookup_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:platform_test",
|
||||
"//tensorflow/python:string_ops",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "range_dataset_op_test",
|
||||
size = "small",
|
||||
srcs = ["range_dataset_op_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/contrib/data",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "reader_dataset_ops_test",
|
||||
size = "small",
|
||||
srcs = ["reader_dataset_ops_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/contrib/data",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "resample_test",
|
||||
size = "medium",
|
||||
srcs = ["resample_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/contrib/data",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:platform_test",
|
||||
"//tensorflow/python:string_ops",
|
||||
"//tensorflow/python:variables",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "sequence_dataset_op_test",
|
||||
size = "small",
|
||||
srcs = ["sequence_dataset_op_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/contrib/data",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "shuffle_dataset_op_test",
|
||||
size = "small",
|
||||
srcs = ["shuffle_dataset_op_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/contrib/data",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "zip_dataset_op_test",
|
||||
size = "small",
|
||||
srcs = ["zip_dataset_op_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/contrib/data",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(
|
||||
["**/*"],
|
||||
exclude = [
|
||||
"**/METADATA",
|
||||
"**/OWNERS",
|
||||
],
|
||||
),
|
||||
visibility = ["//tensorflow:__subpackages__"],
|
||||
)
|
@ -0,0 +1,276 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for the experimental input pipeline ops."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.contrib.data.python.ops import dataset_ops
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import string_ops
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.util import compat
|
||||
|
||||
|
||||
class BatchDatasetTest(test.TestCase):
|
||||
|
||||
def testBatchDataset(self):
|
||||
"""Test an dataset that maps a TF function across its input elements."""
|
||||
# The pipeline is TensorSliceDataset -> MapDataset(square_3) ->
|
||||
# RepeatDataset(count) -> BatchDataset(batch_size).
|
||||
components = [np.arange(7),
|
||||
np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
|
||||
np.array(37.0) * np.arange(7)]
|
||||
|
||||
count = array_ops.placeholder(dtypes.int64, shape=[])
|
||||
batch_size = array_ops.placeholder(dtypes.int64, shape=[])
|
||||
|
||||
def _map_fn(x, y, z):
|
||||
return math_ops.square(x), math_ops.square(y), math_ops.square(z)
|
||||
|
||||
iterator = (dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn)
|
||||
.repeat(count).batch(batch_size).make_initializable_iterator())
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
|
||||
self.assertEqual([[None] + list(c.shape[1:]) for c in components],
|
||||
[t.shape.as_list() for t in get_next])
|
||||
|
||||
with self.test_session() as sess:
|
||||
# Batch of a finite input, where the batch_size divides the
|
||||
# total number of elements.
|
||||
sess.run(init_op, feed_dict={count: 28, batch_size: 14})
|
||||
num_batches = (28 * 7) // 14
|
||||
for i in range(num_batches):
|
||||
result = sess.run(get_next)
|
||||
for component, result_component in zip(components, result):
|
||||
for j in range(14):
|
||||
self.assertAllEqual(component[(i*14 + j) % 7]**2,
|
||||
result_component[j])
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
# Batch of a finite input, where the batch_size does not
|
||||
# divide the total number of elements.
|
||||
sess.run(init_op, feed_dict={count: 14, batch_size: 8})
|
||||
|
||||
# We expect (num_batches - 1) full-sized batches.
|
||||
num_batches = int(math.ceil((14 * 7) / 8))
|
||||
for i in range(num_batches - 1):
|
||||
result = sess.run(get_next)
|
||||
for component, result_component in zip(components, result):
|
||||
for j in range(8):
|
||||
self.assertAllEqual(component[(i*8 + j) % 7]**2,
|
||||
result_component[j])
|
||||
result = sess.run(get_next)
|
||||
for component, result_component in zip(components, result):
|
||||
for j in range((14 * 7) % 8):
|
||||
self.assertAllEqual(component[((num_batches - 1)*8 + j) % 7]**2,
|
||||
result_component[j])
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
# Batch of an empty input should fail straight away.
|
||||
sess.run(init_op, feed_dict={count: 0, batch_size: 8})
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
# Empty batch should be an initialization time error.
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
sess.run(init_op, feed_dict={count: 14, batch_size: 0})
|
||||
|
||||
def testPaddedBatchDataset(self):
|
||||
seq_lens = array_ops.placeholder(dtypes.int32, shape=[None])
|
||||
padded_shape = array_ops.placeholder(dtypes.int64, shape=[1])
|
||||
|
||||
iterator = (dataset_ops.Dataset.from_tensor_slices(seq_lens)
|
||||
.map(lambda x: array_ops.fill([x], x)).padded_batch(
|
||||
4,
|
||||
padded_shapes=padded_shape).make_initializable_iterator())
|
||||
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.test_session() as sess:
|
||||
# Test with random sequence lengths, and max padding.
|
||||
random_seq_lens = np.random.randint(20, size=(32,)).astype(np.int32)
|
||||
sess.run(init_op, feed_dict={padded_shape: [-1],
|
||||
seq_lens: random_seq_lens})
|
||||
for i in range(8):
|
||||
result = sess.run(get_next)
|
||||
padded_len = np.max(result)
|
||||
self.assertEqual((4, padded_len), result.shape)
|
||||
for j in range(4):
|
||||
seq_len = random_seq_lens[(i*4)+j]
|
||||
self.assertAllEqual(result[j, :seq_len], [seq_len] * seq_len)
|
||||
self.assertAllEqual(result[j, seq_len:], [0] * (padded_len - seq_len))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
# Test with random sequence lengths, and constant padding.
|
||||
sess.run(init_op, feed_dict={padded_shape: [25],
|
||||
seq_lens: random_seq_lens})
|
||||
for i in range(8):
|
||||
result = sess.run(get_next)
|
||||
self.assertEqual((4, 25), result.shape)
|
||||
for j in range(4):
|
||||
seq_len = random_seq_lens[(i*4)+j]
|
||||
self.assertAllEqual(result[j, :seq_len], [seq_len] * seq_len)
|
||||
self.assertAllEqual(result[j, seq_len:], [0] * (25 - seq_len))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
# Test correct handling of empty tensors.
|
||||
sess.run(init_op, feed_dict={padded_shape: [-1],
|
||||
seq_lens: [0, 0, 0, 0]})
|
||||
result = sess.run(get_next)
|
||||
self.assertAllEqual([[], [], [], []], result)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
# Test error handling with constant sequence lengths, and
|
||||
# too-short padding.
|
||||
sess.run(init_op, feed_dict={padded_shape: [5],
|
||||
seq_lens: [6, 5, 5, 5]})
|
||||
with self.assertRaises(errors.DataLossError):
|
||||
result = sess.run(get_next)
|
||||
|
||||
def testPaddedBatchDatasetNonDefaultPadding(self):
|
||||
seq_lens = array_ops.placeholder(dtypes.int32, shape=[None])
|
||||
padded_shape = array_ops.placeholder(dtypes.int64, shape=[1])
|
||||
|
||||
def fill_tuple(x):
|
||||
filled = array_ops.fill([x], x)
|
||||
return (filled, string_ops.as_string(filled))
|
||||
iterator = (dataset_ops.Dataset.from_tensor_slices(seq_lens).map(fill_tuple)
|
||||
.padded_batch(
|
||||
4,
|
||||
padded_shapes=(padded_shape, padded_shape),
|
||||
padding_values=(-1, "<end>")).make_initializable_iterator())
|
||||
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.test_session() as sess:
|
||||
# Test with random sequence lengths, and max padding.
|
||||
random_seq_lens = np.random.randint(20, size=(32,)).astype(np.int32)
|
||||
sess.run(init_op, feed_dict={padded_shape: [-1],
|
||||
seq_lens: random_seq_lens})
|
||||
for i in range(8):
|
||||
result = sess.run(get_next)
|
||||
padded_len = np.max(result[0])
|
||||
self.assertEqual((4, padded_len), result[0].shape)
|
||||
self.assertEqual((4, padded_len), result[1].shape)
|
||||
for j in range(4):
|
||||
seq_len = random_seq_lens[(i*4)+j]
|
||||
self.assertAllEqual(result[0][j, :seq_len], [seq_len] * seq_len)
|
||||
self.assertAllEqual(result[0][j, seq_len:],
|
||||
[-1] * (padded_len - seq_len))
|
||||
self.assertAllEqual(result[1][j, :seq_len],
|
||||
[compat.as_bytes(str(seq_len))] * seq_len)
|
||||
self.assertAllEqual(result[1][j, seq_len:],
|
||||
[b"<end>"] * (padded_len - seq_len))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
def testPaddedBatchDatasetShapeSpecifications(self):
|
||||
int_placeholder = array_ops.placeholder(dtypes.int32)
|
||||
float_placeholder = array_ops.placeholder(dtypes.float32)
|
||||
string_placeholder = array_ops.placeholder(dtypes.string)
|
||||
input_dataset = dataset_ops.Dataset.from_tensors(
|
||||
(int_placeholder, float_placeholder, string_placeholder))
|
||||
|
||||
# Test different ways of specifying the `padded_shapes` argument.
|
||||
dynamic_padding_from_tensor_shapes = input_dataset.padded_batch(
|
||||
32,
|
||||
padded_shapes=(tensor_shape.TensorShape([None]),
|
||||
tensor_shape.TensorShape([None, None]),
|
||||
tensor_shape.TensorShape([37])))
|
||||
dynamic_padding_from_lists = input_dataset.padded_batch(
|
||||
32, padded_shapes=([None], [None, None], [37]))
|
||||
dynamic_padding_from_lists_with_minus_one = input_dataset.padded_batch(
|
||||
32, padded_shapes=([-1], [-1, -1], [37]))
|
||||
dynamic_padding_from_tensors = input_dataset.padded_batch(
|
||||
32,
|
||||
padded_shapes=(constant_op.constant([-1], dtype=dtypes.int64),
|
||||
constant_op.constant([-1, -1], dtype=dtypes.int64),
|
||||
constant_op.constant([37], dtype=dtypes.int64)))
|
||||
|
||||
for dataset in [dynamic_padding_from_tensor_shapes,
|
||||
dynamic_padding_from_lists,
|
||||
dynamic_padding_from_lists_with_minus_one,
|
||||
dynamic_padding_from_tensors]:
|
||||
self.assertEqual([None, None], dataset.output_shapes[0].as_list())
|
||||
self.assertEqual([None, None, None], dataset.output_shapes[1].as_list())
|
||||
self.assertEqual([None, 37], dataset.output_shapes[2].as_list())
|
||||
|
||||
def testDenseToSparseBatchDataset(self):
|
||||
components = np.random.randint(12, size=(100,)).astype(np.int32)
|
||||
iterator = (dataset_ops.Dataset.from_tensor_slices(components)
|
||||
.map(lambda x: array_ops.fill([x], x)).dense_to_sparse_batch(
|
||||
4, [12]).make_initializable_iterator())
|
||||
init_op = iterator.initializer
|
||||
get_next = sparse_tensor.SparseTensor(*iterator.get_next())
|
||||
|
||||
with self.test_session() as sess:
|
||||
sess.run(init_op)
|
||||
|
||||
for start in range(0, len(components), 4):
|
||||
results = sess.run(get_next)
|
||||
self.assertAllEqual(
|
||||
[[i, j] for i, c in enumerate(components[start:start+4])
|
||||
for j in range(c)], results.indices)
|
||||
self.assertAllEqual(
|
||||
[c for c in components[start:start+4] for _ in range(c)],
|
||||
results.values)
|
||||
self.assertAllEqual(
|
||||
[min(4, len(components) - start), 12], results.dense_shape)
|
||||
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
def testDenseToSparseBatchDatasetShapeErrors(self):
|
||||
input_tensor = array_ops.placeholder(dtypes.int32)
|
||||
iterator = (dataset_ops.Dataset.from_tensors(input_tensor)
|
||||
.dense_to_sparse_batch(4, [12]).make_initializable_iterator())
|
||||
init_op = iterator.initializer
|
||||
get_next = sparse_tensor.SparseTensor(*iterator.get_next())
|
||||
|
||||
with self.test_session() as sess:
|
||||
# Initialize with an input tensor of incompatible rank.
|
||||
sess.run(init_op, feed_dict={input_tensor: [[1]]})
|
||||
with self.assertRaisesRegexp(errors.InvalidArgumentError,
|
||||
"incompatible with the row shape"):
|
||||
sess.run(get_next)
|
||||
|
||||
# Initialize with an input tensor that is larger than `row_shape`.
|
||||
sess.run(init_op, feed_dict={input_tensor: range(13)})
|
||||
with self.assertRaisesRegexp(errors.DataLossError,
|
||||
"larger than the row shape"):
|
||||
sess.run(get_next)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
292
tensorflow/contrib/data/python/kernel_tests/bucketing_test.py
Normal file
292
tensorflow/contrib/data/python/kernel_tests/bucketing_test.py
Normal file
@ -0,0 +1,292 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for the experimental input pipeline ops."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.contrib.data.python.ops import dataset_ops
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import string_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class BucketingTest(test.TestCase):
|
||||
|
||||
def testSimple(self):
|
||||
components = np.random.randint(100, size=(200,)).astype(np.int64)
|
||||
iterator = dataset_ops.Iterator.from_dataset(
|
||||
dataset_ops.Dataset.from_tensor_slices(components).map(lambda x: x * x)
|
||||
.group_by_window(lambda x: x % 2, lambda _, xs: xs.batch(4), 4))
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.test_session() as sess:
|
||||
sess.run(init_op)
|
||||
counts = []
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
while True:
|
||||
result = sess.run(get_next)
|
||||
self.assertTrue(
|
||||
all(x % 2 == 0 for x in result) or all(x % 2 == 1)
|
||||
for x in result)
|
||||
counts.append(result.shape[0])
|
||||
|
||||
self.assertEqual(len(components), sum(counts))
|
||||
num_full_batches = len([c for c in counts if c == 4])
|
||||
self.assertGreaterEqual(num_full_batches, 23)
|
||||
self.assertTrue(all(c == 4 for c in counts[:num_full_batches]))
|
||||
|
||||
def testImmediateOutput(self):
|
||||
components = np.array(
|
||||
[0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0], dtype=np.int64)
|
||||
iterator = dataset_ops.Iterator.from_dataset(
|
||||
dataset_ops.Dataset.from_tensor_slices(components).repeat(-1)
|
||||
.group_by_window(lambda x: x % 3, lambda _, xs: xs.batch(4), 4))
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.test_session() as sess:
|
||||
sess.run(init_op)
|
||||
# The input is infinite, so this test demonstrates that:
|
||||
# 1. We produce output without having to consume the entire input,
|
||||
# 2. Different buckets can produce output at different rates, and
|
||||
# 3. For deterministic input, the output is deterministic.
|
||||
for _ in range(3):
|
||||
self.assertAllEqual([0, 0, 0, 0], sess.run(get_next))
|
||||
self.assertAllEqual([1, 1, 1, 1], sess.run(get_next))
|
||||
self.assertAllEqual([2, 2, 2, 2], sess.run(get_next))
|
||||
self.assertAllEqual([0, 0, 0, 0], sess.run(get_next))
|
||||
|
||||
def testSmallGroups(self):
|
||||
components = np.array([0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0], dtype=np.int64)
|
||||
iterator = dataset_ops.Iterator.from_dataset(
|
||||
dataset_ops.Dataset.from_tensor_slices(components)
|
||||
.group_by_window(lambda x: x % 2, lambda _, xs: xs.batch(4), 4))
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.test_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.assertAllEqual([0, 0, 0, 0], sess.run(get_next))
|
||||
self.assertAllEqual([1, 1, 1, 1], sess.run(get_next))
|
||||
# The small outputs at the end are deterministically produced in key
|
||||
# order.
|
||||
self.assertAllEqual([0, 0, 0], sess.run(get_next))
|
||||
self.assertAllEqual([1], sess.run(get_next))
|
||||
|
||||
def testReduceFuncError(self):
|
||||
components = np.random.randint(100, size=(200,)).astype(np.int64)
|
||||
|
||||
def reduce_func(_, xs):
|
||||
# Introduce an incorrect padded shape that cannot (currently) be
|
||||
# detected at graph construction time.
|
||||
return xs.padded_batch(
|
||||
4,
|
||||
padded_shapes=(tensor_shape.TensorShape([]),
|
||||
constant_op.constant([5], dtype=dtypes.int64) * -1))
|
||||
|
||||
iterator = dataset_ops.Iterator.from_dataset(
|
||||
dataset_ops.Dataset.from_tensor_slices(components)
|
||||
.map(lambda x: (x, ops.convert_to_tensor([x * x])))
|
||||
.group_by_window(lambda x, _: x % 2, reduce_func, 32))
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.test_session() as sess:
|
||||
sess.run(init_op)
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
sess.run(get_next)
|
||||
|
||||
def testConsumeWindowDatasetMoreThanOnce(self):
|
||||
components = np.random.randint(50, size=(200,)).astype(np.int64)
|
||||
|
||||
def reduce_func(key, window):
|
||||
# Apply two different kinds of padding to the input: tight
|
||||
# padding, and quantized (to a multiple of 10) padding.
|
||||
return dataset_ops.Dataset.zip((window.padded_batch(
|
||||
4,
|
||||
padded_shapes=tensor_shape.TensorShape([None])), window.padded_batch(
|
||||
4, padded_shapes=ops.convert_to_tensor([(key + 1) * 10])),))
|
||||
|
||||
iterator = dataset_ops.Iterator.from_dataset(
|
||||
dataset_ops.Dataset.from_tensor_slices(components)
|
||||
.map(lambda x: array_ops.fill([math_ops.cast(x, dtypes.int32)], x))
|
||||
.group_by_window(
|
||||
lambda x: math_ops.cast(array_ops.shape(x)[0] // 10, dtypes.int64),
|
||||
reduce_func, 4))
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.test_session() as sess:
|
||||
sess.run(init_op)
|
||||
counts = []
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
while True:
|
||||
tight_result, multiple_of_10_result = sess.run(get_next)
|
||||
self.assertEqual(0, multiple_of_10_result.shape[1] % 10)
|
||||
self.assertAllEqual(tight_result,
|
||||
multiple_of_10_result[:, :tight_result.shape[1]])
|
||||
counts.append(tight_result.shape[0])
|
||||
self.assertEqual(len(components), sum(counts))
|
||||
|
||||
|
||||
# NOTE(mrry): These tests are based on the tests in
|
||||
# bucket_ops_test.py. Currently, different batch sizes for each key
|
||||
# are not supported, although this would be possible to add to
|
||||
# `Dataset.group_by_window()`.
|
||||
class BucketTest(test.TestCase):
|
||||
|
||||
def _dynamicPad(self, bucket, window, window_size):
|
||||
# TODO(mrry): To match `tf.contrib.training.bucket()`, implement a
|
||||
# generic form of padded_batch that pads every component
|
||||
# dynamically and does not rely on static shape information about
|
||||
# the arguments.
|
||||
return dataset_ops.Dataset.zip(
|
||||
(dataset_ops.Dataset.from_tensors(bucket), window.padded_batch(
|
||||
32, (tensor_shape.TensorShape([]), tensor_shape.TensorShape([None]),
|
||||
tensor_shape.TensorShape([3])))))
|
||||
|
||||
def testSingleBucket(self):
|
||||
def _map_fn(v):
|
||||
return (v, array_ops.fill([v], v),
|
||||
array_ops.fill([3], string_ops.as_string(v)))
|
||||
|
||||
input_dataset = (
|
||||
dataset_ops.Dataset.from_tensor_slices(math_ops.range(32)).map(_map_fn))
|
||||
|
||||
bucketed_dataset = input_dataset.group_by_window(
|
||||
lambda x, y, z: 0, lambda k, bucket: self._dynamicPad(k, bucket, 32),
|
||||
32)
|
||||
|
||||
iterator = dataset_ops.Iterator.from_dataset(bucketed_dataset)
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.test_session() as sess:
|
||||
sess.run(init_op)
|
||||
|
||||
which_bucket, bucketed_values = sess.run(get_next)
|
||||
|
||||
self.assertEqual(0, which_bucket)
|
||||
|
||||
expected_scalar_int = np.arange(32, dtype=np.int64)
|
||||
expected_unk_int64 = np.zeros((32, 31)).astype(np.int64)
|
||||
for i in range(32):
|
||||
expected_unk_int64[i, :i] = i
|
||||
expected_vec3_str = np.vstack(3 * [np.arange(32).astype(bytes)]).T
|
||||
|
||||
self.assertAllEqual(expected_scalar_int, bucketed_values[0])
|
||||
self.assertAllEqual(expected_unk_int64, bucketed_values[1])
|
||||
self.assertAllEqual(expected_vec3_str, bucketed_values[2])
|
||||
|
||||
def testEvenOddBuckets(self):
|
||||
def _map_fn(v):
|
||||
return (v, array_ops.fill([v], v),
|
||||
array_ops.fill([3], string_ops.as_string(v)))
|
||||
|
||||
input_dataset = (
|
||||
dataset_ops.Dataset.from_tensor_slices(math_ops.range(64)).map(_map_fn))
|
||||
|
||||
bucketed_dataset = input_dataset.group_by_window(
|
||||
lambda x, y, z: math_ops.cast(x % 2, dtypes.int64),
|
||||
lambda k, bucket: self._dynamicPad(k, bucket, 32), 32)
|
||||
|
||||
iterator = dataset_ops.Iterator.from_dataset(bucketed_dataset)
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.test_session() as sess:
|
||||
sess.run(init_op)
|
||||
|
||||
# Get two minibatches (one containing even values, one containing odds)
|
||||
which_bucket_even, bucketed_values_even = sess.run(get_next)
|
||||
which_bucket_odd, bucketed_values_odd = sess.run(get_next)
|
||||
|
||||
# Count number of bucket_tensors.
|
||||
self.assertEqual(3, len(bucketed_values_even))
|
||||
self.assertEqual(3, len(bucketed_values_odd))
|
||||
|
||||
# Ensure bucket 0 was used for all minibatch entries.
|
||||
self.assertAllEqual(0, which_bucket_even)
|
||||
self.assertAllEqual(1, which_bucket_odd)
|
||||
|
||||
# Test the first bucket outputted, the events starting at 0
|
||||
expected_scalar_int = np.arange(0, 32 * 2, 2, dtype=np.int64)
|
||||
expected_unk_int64 = np.zeros((32, 31 * 2)).astype(np.int64)
|
||||
for i in range(0, 32):
|
||||
expected_unk_int64[i, :2 * i] = 2 * i
|
||||
expected_vec3_str = np.vstack(
|
||||
3 * [np.arange(0, 32 * 2, 2).astype(bytes)]).T
|
||||
|
||||
self.assertAllEqual(expected_scalar_int, bucketed_values_even[0])
|
||||
self.assertAllEqual(expected_unk_int64, bucketed_values_even[1])
|
||||
self.assertAllEqual(expected_vec3_str, bucketed_values_even[2])
|
||||
|
||||
# Test the second bucket outputted, the odds starting at 1
|
||||
expected_scalar_int = np.arange(1, 32 * 2 + 1, 2, dtype=np.int64)
|
||||
expected_unk_int64 = np.zeros((32, 31 * 2 + 1)).astype(np.int64)
|
||||
for i in range(0, 32):
|
||||
expected_unk_int64[i, :2 * i + 1] = 2 * i + 1
|
||||
expected_vec3_str = np.vstack(
|
||||
3 * [np.arange(1, 32 * 2 + 1, 2).astype(bytes)]).T
|
||||
|
||||
self.assertAllEqual(expected_scalar_int, bucketed_values_odd[0])
|
||||
self.assertAllEqual(expected_unk_int64, bucketed_values_odd[1])
|
||||
self.assertAllEqual(expected_vec3_str, bucketed_values_odd[2])
|
||||
|
||||
def testEvenOddBucketsFilterOutAllOdd(self):
|
||||
def _map_fn(v):
|
||||
return (v, array_ops.fill([v], v),
|
||||
array_ops.fill([3], string_ops.as_string(v)))
|
||||
|
||||
input_dataset = (
|
||||
dataset_ops.Dataset.from_tensor_slices(math_ops.range(128)).map(_map_fn)
|
||||
.filter(lambda x, y, z: math_ops.equal(x % 2, 0)))
|
||||
|
||||
bucketed_dataset = input_dataset.group_by_window(
|
||||
lambda x, y, z: math_ops.cast(x % 2, dtypes.int64),
|
||||
lambda k, bucket: self._dynamicPad(k, bucket, 32), 32)
|
||||
|
||||
iterator = dataset_ops.Iterator.from_dataset(bucketed_dataset)
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.test_session() as sess:
|
||||
sess.run(init_op)
|
||||
|
||||
# Get two minibatches ([0, 2, ...] and [64, 66, ...])
|
||||
which_bucket0, bucketed_values_even0 = sess.run(get_next)
|
||||
which_bucket1, bucketed_values_even1 = sess.run(get_next)
|
||||
|
||||
# Ensure that bucket 1 was completely filtered out
|
||||
self.assertAllEqual(0, which_bucket0)
|
||||
self.assertAllEqual(0, which_bucket1)
|
||||
self.assertAllEqual(
|
||||
np.arange(0, 64, 2, dtype=np.int64), bucketed_values_even0[0])
|
||||
self.assertAllEqual(
|
||||
np.arange(64, 128, 2, dtype=np.int64), bucketed_values_even1[0])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
@ -0,0 +1,239 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for the experimental input pipeline ops."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.contrib.data.python.ops import dataset_ops
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.util import nest
|
||||
|
||||
|
||||
class DatasetConstructorTest(test.TestCase):
|
||||
|
||||
def testTensorDataset(self):
|
||||
"""Test an dataset that represents a single tuple of tensors."""
|
||||
components = [np.array(1), np.array([1, 2, 3]), np.array(37.0)]
|
||||
|
||||
iterator = (dataset_ops.Dataset.from_tensors(components)
|
||||
.make_initializable_iterator())
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
|
||||
self.assertEqual([c.shape for c in components],
|
||||
[t.shape for t in get_next])
|
||||
|
||||
with self.test_session() as sess:
|
||||
sess.run(init_op)
|
||||
results = sess.run(get_next)
|
||||
for component, result_component in zip(components, results):
|
||||
self.assertAllEqual(component, result_component)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
def testTensorSliceDataset(self):
|
||||
"""Test an dataset that represents the slices from a tuple of tensors."""
|
||||
components = [
|
||||
np.tile(np.array([[1], [2], [3], [4]]), 20), np.tile(
|
||||
np.array([[12], [13], [14], [15]]), 22),
|
||||
np.array([37.0, 38.0, 39.0, 40.0])
|
||||
]
|
||||
|
||||
iterator = (dataset_ops.Dataset.from_tensor_slices(components)
|
||||
.make_initializable_iterator())
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
|
||||
self.assertEqual([c.shape[1:] for c in components],
|
||||
[t.shape for t in get_next])
|
||||
|
||||
with self.test_session() as sess:
|
||||
sess.run(init_op)
|
||||
for i in range(4):
|
||||
results = sess.run(get_next)
|
||||
for component, result_component in zip(components, results):
|
||||
self.assertAllEqual(component[i], result_component)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
def testSparseTensorSliceDataset(self):
|
||||
"""Test a dataset based on slices of a `tf.SparseTensor`."""
|
||||
st = array_ops.sparse_placeholder(dtypes.float64)
|
||||
iterator = (dataset_ops.Dataset.from_sparse_tensor_slices(st)
|
||||
.make_initializable_iterator())
|
||||
init_op = iterator.initializer
|
||||
get_next = sparse_tensor.SparseTensor(*iterator.get_next())
|
||||
|
||||
with self.test_session() as sess:
|
||||
slices = [[1., 2., 3.], [1.], [1.], [1., 2.], [], [1., 2.], [], [], []]
|
||||
|
||||
# Test with sparse tensor in the appropriate order.
|
||||
indices = np.array(
|
||||
[[i, j] for i in range(len(slices)) for j in range(len(slices[i]))])
|
||||
values = np.array([val for s in slices for val in s])
|
||||
dense_shape = np.array([len(slices), max(len(s) for s in slices) + 1])
|
||||
sparse_feed = sparse_tensor.SparseTensorValue(indices, values,
|
||||
dense_shape)
|
||||
sess.run(init_op, feed_dict={st: sparse_feed})
|
||||
for i, s in enumerate(slices):
|
||||
results = sess.run(get_next)
|
||||
self.assertAllEqual(s, results.values)
|
||||
expected_indices = np.array(
|
||||
[[j] for j in range(len(slices[i]))]).reshape([-1, 1])
|
||||
self.assertAllEqual(expected_indices, results.indices)
|
||||
self.assertAllEqual(dense_shape[1:], results.dense_shape)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
# Test with sparse tensor in the reverse order, which is not
|
||||
# currently supported.
|
||||
reverse_order_indices = indices[::-1, :]
|
||||
reverse_order_values = values[::-1]
|
||||
sparse_feed = sparse_tensor.SparseTensorValue(
|
||||
reverse_order_indices, reverse_order_values, dense_shape)
|
||||
with self.assertRaises(errors.UnimplementedError):
|
||||
sess.run(init_op, feed_dict={st: sparse_feed})
|
||||
|
||||
# Test with an empty sparse tensor.
|
||||
empty_indices = np.empty((0, 4), dtype=np.int64)
|
||||
empty_values = np.empty((0,), dtype=np.float64)
|
||||
empty_dense_shape = [0, 4, 37, 9]
|
||||
sparse_feed = sparse_tensor.SparseTensorValue(empty_indices, empty_values,
|
||||
empty_dense_shape)
|
||||
sess.run(init_op, feed_dict={st: sparse_feed})
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
# pylint: disable=g-long-lambda,unnecessary-lambda
|
||||
def testNestedStructure(self):
|
||||
components = (np.array([1, 2, 3]), (np.array([4., 5.]), np.array([6., 7.])),
|
||||
np.array([8, 9, 10]))
|
||||
|
||||
dataset = dataset_ops.Dataset.from_tensors(components)
|
||||
self.assertEquals((dtypes.int64, (dtypes.float64, dtypes.float64),
|
||||
dtypes.int64), dataset.output_types)
|
||||
self.assertEquals(([3], ([2], [2]), [3]), dataset.output_shapes)
|
||||
|
||||
dataset = dataset.shuffle(10, 10)
|
||||
self.assertEquals((dtypes.int64, (dtypes.float64, dtypes.float64),
|
||||
dtypes.int64), dataset.output_types)
|
||||
self.assertEquals(([3], ([2], [2]), [3]), dataset.output_shapes)
|
||||
|
||||
dataset = dataset.repeat(-1)
|
||||
self.assertEquals((dtypes.int64, (dtypes.float64, dtypes.float64),
|
||||
dtypes.int64), dataset.output_types)
|
||||
self.assertEquals(([3], ([2], [2]), [3]), dataset.output_shapes)
|
||||
|
||||
dataset = dataset.filter(lambda x, y, z: True)
|
||||
self.assertEquals((dtypes.int64, (dtypes.float64, dtypes.float64),
|
||||
dtypes.int64), dataset.output_types)
|
||||
self.assertEquals(([3], ([2], [2]), [3]), dataset.output_shapes)
|
||||
|
||||
dataset = dataset.take(5)
|
||||
self.assertEquals((dtypes.int64, (dtypes.float64, dtypes.float64),
|
||||
dtypes.int64), dataset.output_types)
|
||||
self.assertEquals(([3], ([2], [2]), [3]), dataset.output_shapes)
|
||||
|
||||
dataset = dataset.map(lambda x, y, z: ((x, z), (y[0], y[1])))
|
||||
self.assertEquals(((dtypes.int64, dtypes.int64),
|
||||
(dtypes.float64, dtypes.float64)), dataset.output_types)
|
||||
self.assertEquals((([3], [3]), ([2], [2])), dataset.output_shapes)
|
||||
|
||||
dataset = dataset.flat_map(
|
||||
lambda x, y: dataset_ops.Dataset.from_tensors(((x[0], x[1]),
|
||||
(y[0], y[1])))
|
||||
)
|
||||
self.assertEquals(((dtypes.int64, dtypes.int64),
|
||||
(dtypes.float64, dtypes.float64)), dataset.output_types)
|
||||
self.assertEquals((([3], [3]), ([2], [2])), dataset.output_shapes)
|
||||
|
||||
dataset = dataset.batch(32)
|
||||
self.assertEquals(((dtypes.int64, dtypes.int64),
|
||||
(dtypes.float64, dtypes.float64)), dataset.output_types)
|
||||
self.assertEquals((([None, 3], [None, 3]), ([None, 2], [None, 2])),
|
||||
nest.pack_sequence_as(dataset.output_shapes, [
|
||||
s.as_list()
|
||||
for s in nest.flatten(dataset.output_shapes)
|
||||
]))
|
||||
|
||||
iterator = dataset.make_one_shot_iterator()
|
||||
(w, x), (y, z) = iterator.get_next()
|
||||
self.assertEquals(dtypes.int64, w.dtype)
|
||||
self.assertEquals(dtypes.int64, x.dtype)
|
||||
self.assertEquals(dtypes.float64, y.dtype)
|
||||
self.assertEquals(dtypes.float64, z.dtype)
|
||||
self.assertEquals([None, 3], w.shape.as_list())
|
||||
self.assertEquals([None, 3], x.shape.as_list())
|
||||
self.assertEquals([None, 2], y.shape.as_list())
|
||||
self.assertEquals([None, 2], z.shape.as_list())
|
||||
|
||||
iterator = dataset.make_initializable_iterator()
|
||||
(w, x), (y, z) = iterator.get_next()
|
||||
self.assertEquals(dtypes.int64, w.dtype)
|
||||
self.assertEquals(dtypes.int64, x.dtype)
|
||||
self.assertEquals(dtypes.float64, y.dtype)
|
||||
self.assertEquals(dtypes.float64, z.dtype)
|
||||
self.assertEquals([None, 3], w.shape.as_list())
|
||||
self.assertEquals([None, 3], x.shape.as_list())
|
||||
self.assertEquals([None, 2], y.shape.as_list())
|
||||
self.assertEquals([None, 2], z.shape.as_list())
|
||||
|
||||
# Define a separate set of components with matching leading
|
||||
# dimension for the from-slices constructor.
|
||||
components_for_slices = (np.array([1, 2, 3]), (np.array(
|
||||
[4., 5., 6.]), np.array([7., 8., 9.])), np.array([10, 11, 12]))
|
||||
|
||||
dataset = dataset_ops.Dataset.from_tensor_slices(components_for_slices)
|
||||
self.assertEquals((dtypes.int64, (dtypes.float64, dtypes.float64),
|
||||
dtypes.int64), dataset.output_types)
|
||||
self.assertEquals(([], ([], []), []), dataset.output_shapes)
|
||||
|
||||
def testNonSequenceNestedStructure(self):
|
||||
components = np.array([1, 2, 3])
|
||||
|
||||
dataset = dataset_ops.Dataset.from_tensors(components)
|
||||
self.assertEquals(dtypes.int64, dataset.output_types)
|
||||
self.assertEquals([3], dataset.output_shapes)
|
||||
|
||||
dataset = dataset.filter(
|
||||
lambda x: math_ops.reduce_all(math_ops.equal(x, components)))
|
||||
self.assertEquals(dtypes.int64, dataset.output_types)
|
||||
self.assertEquals([3], dataset.output_shapes)
|
||||
|
||||
dataset = dataset.map(lambda x: array_ops.stack([x, x]))
|
||||
self.assertEquals(dtypes.int64, dataset.output_types)
|
||||
self.assertEquals([2, 3], dataset.output_shapes)
|
||||
|
||||
dataset = dataset.flat_map(
|
||||
lambda x: dataset_ops.Dataset.from_tensor_slices(x))
|
||||
self.assertEquals(dtypes.int64, dataset.output_types)
|
||||
self.assertEquals([3], dataset.output_shapes)
|
||||
|
||||
iterator = dataset.make_one_shot_iterator()
|
||||
get_next = iterator.get_next()
|
||||
self.assertEquals(dtypes.int64, get_next.dtype)
|
||||
self.assertEquals([3], get_next.shape)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
@ -0,0 +1,77 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for the experimental input pipeline ops."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.contrib.data.python.ops import dataset_ops
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class FilterDatasetTest(test.TestCase):
|
||||
|
||||
def testFilterDataset(self):
|
||||
components = [
|
||||
np.arange(7, dtype=np.int64),
|
||||
np.array([[1, 2, 3]], dtype=np.int64) * np.arange(
|
||||
7, dtype=np.int64)[:, np.newaxis],
|
||||
np.array(37.0, dtype=np.float64) * np.arange(7)
|
||||
]
|
||||
count = array_ops.placeholder(dtypes.int64, shape=[])
|
||||
modulus = array_ops.placeholder(dtypes.int64)
|
||||
|
||||
def _map_fn(x, y, z):
|
||||
return math_ops.square(x), math_ops.square(y), math_ops.square(z)
|
||||
|
||||
iterator = (
|
||||
dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn)
|
||||
.repeat(count)
|
||||
.filter(lambda x, _y, _z: math_ops.equal(math_ops.mod(x, modulus), 0))
|
||||
.make_initializable_iterator())
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
|
||||
self.assertEqual([c.shape[1:] for c in components],
|
||||
[t.shape for t in get_next])
|
||||
|
||||
with self.test_session() as sess:
|
||||
# Test that we can dynamically feed a different modulus value for each
|
||||
# iterator.
|
||||
def do_test(count_val, modulus_val):
|
||||
sess.run(init_op, feed_dict={count: count_val, modulus: modulus_val})
|
||||
for _ in range(count_val):
|
||||
for i in [x for x in range(7) if x**2 % modulus_val == 0]:
|
||||
result = sess.run(get_next)
|
||||
for component, result_component in zip(components, result):
|
||||
self.assertAllEqual(component[i]**2, result_component)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
do_test(14, 2)
|
||||
do_test(4, 18)
|
||||
|
||||
# Test an empty dataset.
|
||||
do_test(0, 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
@ -0,0 +1,108 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for the experimental input pipeline ops."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.contrib.data.python.ops import dataset_ops
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.training import server_lib
|
||||
|
||||
|
||||
class FlatMapDatasetTest(test.TestCase):
|
||||
|
||||
# pylint: disable=g-long-lambda
|
||||
def testFlatMapDataset(self):
|
||||
repeats = [1, 2, 3, 4, 5, 0, 1]
|
||||
components = [np.array(repeats, dtype=np.int64)]
|
||||
iterator = (
|
||||
dataset_ops.Dataset.from_tensor_slices(components)
|
||||
.flat_map(lambda x: dataset_ops.Dataset.from_tensors([x]).repeat(x))
|
||||
.make_initializable_iterator())
|
||||
init_op = iterator.initializer
|
||||
get_next, = iterator.get_next()
|
||||
|
||||
with self.test_session() as sess:
|
||||
sess.run(init_op)
|
||||
for i in repeats:
|
||||
for _ in range(i):
|
||||
self.assertEqual(i, sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
def testNestedFlatMapDataset(self):
|
||||
repeats = [[1, 2], [3, 4], [5, 0], [1, 7]]
|
||||
components = [np.array(repeats, dtype=np.int64)]
|
||||
iterator = (
|
||||
dataset_ops.Dataset.from_tensor_slices(components)
|
||||
.flat_map(lambda x: dataset_ops.Dataset.from_tensor_slices([x])
|
||||
.flat_map(lambda y: dataset_ops.Dataset.from_tensors([y])
|
||||
.repeat(y))).make_initializable_iterator())
|
||||
init_op = iterator.initializer
|
||||
get_next, = iterator.get_next()
|
||||
|
||||
with self.test_session() as sess:
|
||||
sess.run(init_op)
|
||||
for row in repeats:
|
||||
for i in row:
|
||||
for _ in range(i):
|
||||
self.assertEqual(i, sess.run(get_next))
|
||||
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
def testSharedResourceNestedFlatMapDataset(self):
|
||||
repeats = [[1, 2], [3, 4], [5, 0], [1, 7]]
|
||||
components = [np.array(repeats, dtype=np.int64)]
|
||||
iterator = (
|
||||
dataset_ops.Dataset.from_tensor_slices(components)
|
||||
.flat_map(lambda x: dataset_ops.Dataset.from_tensor_slices([x])
|
||||
.flat_map(lambda y: dataset_ops.Dataset.from_tensors([y])
|
||||
.repeat(y))).make_initializable_iterator(
|
||||
shared_name="shared_flat_map_iterator"))
|
||||
init_op = iterator.initializer
|
||||
get_next, = iterator.get_next()
|
||||
|
||||
# Create two concurrent sessions that share the same iterator
|
||||
# resource on the same server, and verify that a random
|
||||
# interleaving of `Session.run(get_next)` calls on the two
|
||||
# sessions yields the expected result.
|
||||
server = server_lib.Server.create_local_server()
|
||||
with session.Session(server.target) as sess1:
|
||||
with session.Session(server.target) as sess2:
|
||||
for _ in range(3):
|
||||
sess = random.choice([sess1, sess2])
|
||||
sess.run(init_op)
|
||||
for row in repeats:
|
||||
for i in row:
|
||||
for _ in range(i):
|
||||
sess = random.choice([sess1, sess2])
|
||||
self.assertEqual(i, sess.run(get_next))
|
||||
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess = random.choice([sess1, sess2])
|
||||
sess.run(get_next)
|
||||
# pylint: enable=g-long-lambda
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
252
tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py
Normal file
252
tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py
Normal file
@ -0,0 +1,252 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for the experimental input pipeline ops."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.contrib.data.python.ops import dataset_ops
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.training import server_lib
|
||||
|
||||
|
||||
class IteratorTest(test.TestCase):
|
||||
|
||||
def testOneShotIterator(self):
|
||||
components = [np.arange(7),
|
||||
np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
|
||||
np.array(37.0) * np.arange(7)]
|
||||
|
||||
def _map_fn(x, y, z):
|
||||
return math_ops.square(x), math_ops.square(y), math_ops.square(z)
|
||||
|
||||
iterator = (dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn)
|
||||
.repeat(14).make_one_shot_iterator())
|
||||
get_next = iterator.get_next()
|
||||
|
||||
self.assertEqual([c.shape[1:] for c in components],
|
||||
[t.shape for t in get_next])
|
||||
|
||||
with self.test_session() as sess:
|
||||
for _ in range(14):
|
||||
for i in range(7):
|
||||
result = sess.run(get_next)
|
||||
for component, result_component in zip(components, result):
|
||||
self.assertAllEqual(component[i]**2, result_component)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
def testOneShotIteratorCaptureByValue(self):
|
||||
components = [np.arange(7),
|
||||
np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
|
||||
np.array(37.0) * np.arange(7)]
|
||||
tensor_components = [ops.convert_to_tensor(c) for c in components]
|
||||
|
||||
def _map_fn(x, y, z):
|
||||
return math_ops.square(x), math_ops.square(y), math_ops.square(z)
|
||||
|
||||
iterator = (dataset_ops.Dataset.from_tensor_slices(tensor_components)
|
||||
.map(_map_fn).repeat(14).make_one_shot_iterator())
|
||||
get_next = iterator.get_next()
|
||||
|
||||
self.assertEqual([c.shape[1:] for c in components],
|
||||
[t.shape for t in get_next])
|
||||
|
||||
with self.test_session() as sess:
|
||||
for _ in range(14):
|
||||
for i in range(7):
|
||||
result = sess.run(get_next)
|
||||
for component, result_component in zip(components, result):
|
||||
self.assertAllEqual(component[i]**2, result_component)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
def testOneShotIteratorInsideContainer(self):
|
||||
components = [np.arange(7),
|
||||
np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
|
||||
np.array(37.0) * np.arange(7)]
|
||||
|
||||
def within_container():
|
||||
def _map_fn(x, y, z):
|
||||
return math_ops.square(x), math_ops.square(y), math_ops.square(z)
|
||||
iterator = (dataset_ops.Dataset.from_tensor_slices(components)
|
||||
.map(_map_fn).repeat(14).make_one_shot_iterator())
|
||||
return iterator.get_next()
|
||||
|
||||
server = server_lib.Server.create_local_server()
|
||||
|
||||
# Create two iterators within unique containers, and run them to
|
||||
# make sure that the resources aren't shared.
|
||||
#
|
||||
# The test below would fail if cname were the same across both
|
||||
# sessions.
|
||||
for i in range(2):
|
||||
with session.Session(server.target) as sess:
|
||||
cname = "iteration%d" % i
|
||||
with ops.container(cname):
|
||||
get_next = within_container()
|
||||
|
||||
for _ in range(14):
|
||||
for i in range(7):
|
||||
result = sess.run(get_next)
|
||||
for component, result_component in zip(components, result):
|
||||
self.assertAllEqual(component[i]**2, result_component)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
def testSimpleSharedResource(self):
|
||||
components = [
|
||||
np.array(1, dtype=np.int64),
|
||||
np.array([1, 2, 3], dtype=np.int64),
|
||||
np.array(37.0, dtype=np.float64)
|
||||
]
|
||||
|
||||
server = server_lib.Server.create_local_server()
|
||||
|
||||
# Create two non-overlapping sessions that share the same iterator
|
||||
# resource on the same server, and verify that an action of the
|
||||
# first session (initializing the iterator) is visible in the
|
||||
# second session.
|
||||
with ops.Graph().as_default():
|
||||
iterator = (dataset_ops.Dataset.from_tensors(components)
|
||||
.map(lambda x, y, z: (x, y, z)).make_initializable_iterator(
|
||||
shared_name="shared_iterator"))
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with session.Session(server.target) as sess:
|
||||
sess.run(init_op)
|
||||
results = sess.run(get_next)
|
||||
for component, result_component in zip(components, results):
|
||||
self.assertAllEqual(component, result_component)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
# Re-initialize the iterator in the first session.
|
||||
sess.run(init_op)
|
||||
|
||||
with ops.Graph().as_default():
|
||||
# Re-define the iterator manually, without defining any of the
|
||||
# functions in this graph, to ensure that we are not
|
||||
# accidentally redefining functions with the same names in the
|
||||
# new graph.
|
||||
iterator = dataset_ops.Iterator.from_structure(
|
||||
shared_name="shared_iterator",
|
||||
output_types=[dtypes.int64, dtypes.int64, dtypes.float64],
|
||||
output_shapes=[[], [3], []])
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with session.Session(server.target) as sess:
|
||||
# Use the iterator without re-initializing in the second session.
|
||||
results = sess.run(get_next)
|
||||
for component, result_component in zip(components, results):
|
||||
self.assertAllEqual(component, result_component)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
def testNotInitializedError(self):
|
||||
components = [np.array(1), np.array([1, 2, 3]), np.array(37.0)]
|
||||
iterator = (dataset_ops.Dataset.from_tensors(components)
|
||||
.make_initializable_iterator())
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.test_session() as sess:
|
||||
with self.assertRaisesRegexp(errors.FailedPreconditionError,
|
||||
"iterator has not been initialized"):
|
||||
sess.run(get_next)
|
||||
|
||||
def testReinitializableIterator(self):
|
||||
dataset_3 = dataset_ops.Dataset.from_tensors(
|
||||
constant_op.constant([1, 2, 3]))
|
||||
dataset_4 = dataset_ops.Dataset.from_tensors(
|
||||
constant_op.constant([4, 5, 6, 7]))
|
||||
iterator = dataset_ops.Iterator.from_structure(dataset_3.output_types,
|
||||
[None])
|
||||
|
||||
dataset_3_init_op = iterator.make_initializer(dataset_3)
|
||||
dataset_4_init_op = iterator.make_initializer(dataset_4)
|
||||
get_next = iterator.get_next()
|
||||
|
||||
self.assertEqual(dataset_3.output_types, iterator.output_types)
|
||||
self.assertEqual(dataset_4.output_types, iterator.output_types)
|
||||
self.assertEqual([None], iterator.output_shapes.as_list())
|
||||
|
||||
with self.test_session() as sess:
|
||||
# The iterator is initially uninitialized.
|
||||
with self.assertRaises(errors.FailedPreconditionError):
|
||||
sess.run(get_next)
|
||||
|
||||
# Initialize with one dataset.
|
||||
sess.run(dataset_3_init_op)
|
||||
self.assertAllEqual([1, 2, 3], sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
# Initialize with a different dataset.
|
||||
sess.run(dataset_4_init_op)
|
||||
self.assertAllEqual([4, 5, 6, 7], sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
# Reinitialize with the first dataset.
|
||||
sess.run(dataset_3_init_op)
|
||||
self.assertAllEqual([1, 2, 3], sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
def testReinitializableIteratorStaticErrors(self):
|
||||
# Non-matching structure for types and shapes.
|
||||
with self.assertRaises(TypeError):
|
||||
iterator = dataset_ops.Iterator.from_structure((dtypes.int64,
|
||||
dtypes.float64), [None])
|
||||
|
||||
# Test validation of dataset argument.
|
||||
iterator = dataset_ops.Iterator.from_structure((dtypes.int64,
|
||||
dtypes.float64))
|
||||
|
||||
# Incompatible structure.
|
||||
with self.assertRaises(ValueError):
|
||||
iterator.make_initializer(
|
||||
dataset_ops.Dataset.from_tensors(((constant_op.constant(
|
||||
[1, 2, 3], dtype=dtypes.int64),), (constant_op.constant(
|
||||
[4., 5., 6., 7.], dtype=dtypes.float64),))))
|
||||
|
||||
# Incompatible types.
|
||||
with self.assertRaises(TypeError):
|
||||
iterator.make_initializer(
|
||||
dataset_ops.Dataset.from_tensors((constant_op.constant(
|
||||
[1, 2, 3], dtype=dtypes.int32), constant_op.constant(
|
||||
[4., 5., 6., 7.], dtype=dtypes.float32))))
|
||||
|
||||
# Incompatible shapes.
|
||||
iterator = dataset_ops.Iterator.from_structure(
|
||||
(dtypes.int64, dtypes.float64), ([None], []))
|
||||
with self.assertRaises(TypeError):
|
||||
iterator.make_initializer(
|
||||
dataset_ops.Dataset.from_tensors((constant_op.constant(
|
||||
[1, 2, 3], dtype=dtypes.int64), constant_op.constant(
|
||||
[4., 5., 6., 7.], dtype=dtypes.float64))))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
@ -0,0 +1,330 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for the experimental input pipeline ops."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.contrib.data.python.ops import dataset_ops
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import data_flow_ops
|
||||
from tensorflow.python.ops import lookup_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.ops import string_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class MapDatasetTest(test.TestCase):
|
||||
|
||||
def _buildMapDataset(self, components, count):
|
||||
def _map_fn(x, y, z):
|
||||
return math_ops.square(x), math_ops.square(y), math_ops.square(z)
|
||||
return (dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn)
|
||||
.repeat(count))
|
||||
|
||||
def testMapDataset(self):
|
||||
"""Test an dataset that maps a TF function across its input elements."""
|
||||
# The pipeline is TensorSliceDataset -> MapDataset(square_3) ->
|
||||
# RepeatDataset(count).
|
||||
components = [np.arange(7),
|
||||
np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
|
||||
np.array(37.0) * np.arange(7)]
|
||||
count = array_ops.placeholder(dtypes.int64, shape=[])
|
||||
|
||||
dataset = self._buildMapDataset(components, count)
|
||||
iterator = dataset.make_initializable_iterator()
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
|
||||
self.assertEqual([c.shape[1:] for c in components],
|
||||
[t.shape for t in get_next])
|
||||
|
||||
with self.test_session() as sess:
|
||||
# Test single-threaded access to the iterator.
|
||||
sess.run(init_op, feed_dict={count: 14})
|
||||
for _ in range(14):
|
||||
for i in range(7):
|
||||
result = sess.run(get_next)
|
||||
for component, result_component in zip(components, result):
|
||||
self.assertAllEqual(component[i]**2, result_component)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
# Test multi-threaded access to the same iterator.
|
||||
sess.run(init_op, feed_dict={count: 18})
|
||||
results = []
|
||||
def iterator_thread():
|
||||
while True:
|
||||
try:
|
||||
results.append(sess.run(get_next))
|
||||
except errors.OutOfRangeError:
|
||||
return
|
||||
threads = [self.checkedThread(target=iterator_thread) for _ in range(8)]
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
# `results` will contain the same elements components**2
|
||||
# repeated 18 times, but in a non-deterministic order. Sort the
|
||||
# results, and assert that each element of components**2 is
|
||||
# produced 18 times.
|
||||
results.sort(key=lambda x: x[0])
|
||||
for i in range(7):
|
||||
for j in range(18):
|
||||
for component, result_component in zip(components,
|
||||
results[i * 18 + j]):
|
||||
self.assertAllEqual(component[i]**2, result_component)
|
||||
|
||||
def _buildParallelMapDataset(self, components, count, num_threads,
|
||||
output_buffer_size):
|
||||
def _map_fn(x, y, z):
|
||||
return math_ops.square(x), math_ops.square(y), math_ops.square(z)
|
||||
return (dataset_ops.Dataset.from_tensor_slices(components).map(
|
||||
_map_fn, num_threads=num_threads, output_buffer_size=output_buffer_size)
|
||||
.repeat(count))
|
||||
|
||||
def testParallelMapDataset(self):
|
||||
"""Test an dataset that maps a TF function across its input elements."""
|
||||
# The pipeline is TensorSliceDataset -> ParallelMapDataset(square_3) ->
|
||||
# RepeatDataset(count).
|
||||
components = [np.arange(7),
|
||||
np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
|
||||
np.array(37.0) * np.arange(7)]
|
||||
count = array_ops.placeholder(dtypes.int64, shape=[])
|
||||
num_threads = array_ops.placeholder(dtypes.int32, shape=[])
|
||||
output_buffer_size = array_ops.placeholder(dtypes.int64, shape=[])
|
||||
|
||||
dataset = self._buildParallelMapDataset(components, count, num_threads,
|
||||
output_buffer_size)
|
||||
iterator = dataset.make_initializable_iterator()
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
|
||||
self.assertEqual([c.shape[1:] for c in components],
|
||||
[t.shape for t in get_next])
|
||||
|
||||
with self.test_session() as sess:
|
||||
def do_test(num_threads_val, output_buffer_size_val):
|
||||
# Test single-threaded access to the iterator.
|
||||
sess.run(init_op, feed_dict={
|
||||
count: 14,
|
||||
num_threads: num_threads_val,
|
||||
output_buffer_size: output_buffer_size_val})
|
||||
for _ in range(14):
|
||||
for i in range(7):
|
||||
result = sess.run(get_next)
|
||||
for component, result_component in zip(components, result):
|
||||
self.assertAllEqual(component[i]**2, result_component)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
# Test multi-threaded access to the same iterator.
|
||||
sess.run(init_op, feed_dict={
|
||||
count: 18,
|
||||
num_threads: num_threads_val,
|
||||
output_buffer_size: output_buffer_size_val})
|
||||
results = []
|
||||
def iterator_thread():
|
||||
while True:
|
||||
try:
|
||||
results.append(sess.run(get_next))
|
||||
except errors.OutOfRangeError:
|
||||
return
|
||||
threads = [self.checkedThread(target=iterator_thread) for _ in range(8)]
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
# `results` will contain the same elements components**2
|
||||
# repeated 18 times, but in a non-deterministic order. Sort the
|
||||
# results, and assert that each element of components**2 is
|
||||
# produced 18 times.
|
||||
results.sort(key=lambda x: x[0])
|
||||
for i in range(7):
|
||||
for j in range(18):
|
||||
for component, result_component in zip(components,
|
||||
results[i * 18 + j]):
|
||||
self.assertAllEqual(component[i]**2, result_component)
|
||||
|
||||
for num_threads_val, output_buffer_size_val in [
|
||||
(1, 1), (1, 2), (2, 2), (2, 4), (8, 8), (8, 16)]:
|
||||
do_test(num_threads_val, output_buffer_size_val)
|
||||
|
||||
def _testDisposeParallelMapDataset(self, explicit_dispose):
|
||||
# The pipeline is TensorSliceDataset -> MapDataset(square_3) ->
|
||||
# RepeatDataset(1000).
|
||||
components = [np.arange(1000),
|
||||
np.array([[1, 2, 3]]) * np.arange(1000)[:, np.newaxis],
|
||||
np.array(37.0) * np.arange(1000)]
|
||||
|
||||
dataset = self._buildParallelMapDataset(components, 1000, 100, 100)
|
||||
iterator = dataset.make_initializable_iterator()
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
if explicit_dispose:
|
||||
dispose_op = iterator.dispose_op()
|
||||
|
||||
with self.test_session() as sess:
|
||||
sess.run(init_op)
|
||||
for _ in range(3):
|
||||
sess.run(get_next)
|
||||
if explicit_dispose:
|
||||
sess.run(dispose_op)
|
||||
|
||||
def testExplicitDisposeParallelMapDataset(self):
|
||||
self._testDisposeParallelMapDataset(True)
|
||||
|
||||
def testImplicitDisposeParallelMapDataset(self):
|
||||
self._testDisposeParallelMapDataset(False)
|
||||
|
||||
def testParallelMapError(self):
|
||||
components = [np.array([1., 2., 3., np.nan, 5.]).astype(np.float32)]
|
||||
|
||||
dataset = (dataset_ops.Dataset.from_tensor_slices(components)
|
||||
.map(lambda x: array_ops.check_numerics(x, "message")))
|
||||
iterator = dataset.make_initializable_iterator()
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.test_session() as sess:
|
||||
sess.run(init_op)
|
||||
for _ in range(3):
|
||||
sess.run(get_next)
|
||||
# The 4th element is NaN, so `array_ops.check_numerics()` should fail.
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
sess.run(get_next)
|
||||
sess.run(get_next)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
def testCaptureHashTable(self):
|
||||
# NOTE(mrry): We must use the V2 variants of `HashTable`
|
||||
# etc. because these produce a `tf.resource`-typed output that is
|
||||
# compatible with the in-graph function implementation.
|
||||
default_val = -1
|
||||
keys = constant_op.constant(["brain", "salad", "surgery"])
|
||||
values = constant_op.constant([0, 1, 2], dtypes.int64)
|
||||
table = lookup_ops.HashTable(
|
||||
lookup_ops.KeyValueTensorInitializer(keys, values), default_val)
|
||||
|
||||
input_sentences = dataset_ops.Dataset.from_tensor_slices(
|
||||
constant_op.constant([
|
||||
"brain brain tank salad surgery",
|
||||
"surgery brain",
|
||||
]))
|
||||
|
||||
iterator = (input_sentences
|
||||
.map(lambda x: string_ops.string_split([x]).values)
|
||||
.map(table.lookup)
|
||||
.make_initializable_iterator())
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.test_session() as sess:
|
||||
sess.run(table.init)
|
||||
sess.run(init_op)
|
||||
|
||||
print(sess.run(get_next))
|
||||
print(sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
def testCaptureQueue(self):
|
||||
elements = np.random.randint(100, size=[200])
|
||||
queue = data_flow_ops.FIFOQueue(200, dtypes.int64, shapes=[])
|
||||
enqueue_op = queue.enqueue_many(elements)
|
||||
close_op = queue.close()
|
||||
iterator = (dataset_ops.Dataset.from_tensors(0).repeat(-1)
|
||||
.map(lambda _: queue.dequeue()).make_initializable_iterator())
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.test_session() as sess:
|
||||
sess.run(enqueue_op)
|
||||
sess.run(close_op)
|
||||
sess.run(init_op)
|
||||
for element in elements:
|
||||
self.assertEqual(element, sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
def testCaptureVariable(self):
|
||||
counter_var = variable_scope.get_variable(
|
||||
"counter", (), dtypes.int32, use_resource=True)
|
||||
iterator = (dataset_ops.Dataset.from_tensors(0).repeat(10)
|
||||
.map(lambda _: counter_var.assign_add(1))
|
||||
.make_initializable_iterator())
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.test_session() as sess:
|
||||
sess.run(counter_var.initializer)
|
||||
sess.run(init_op)
|
||||
for i in range(10):
|
||||
self.assertEqual(i, sess.run(counter_var))
|
||||
self.assertEqual(i + 1, sess.run(get_next))
|
||||
self.assertEqual(10, sess.run(counter_var))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
self.assertEqual(10, sess.run(counter_var))
|
||||
|
||||
def testCaptureUninitializedVariableError(self):
|
||||
counter_var = variable_scope.get_variable(
|
||||
"counter", (), dtypes.int32, use_resource=True)
|
||||
iterator = (dataset_ops.Dataset.from_tensors(0).repeat(10)
|
||||
.map(lambda _: counter_var.assign_add(1))
|
||||
.make_initializable_iterator())
|
||||
init_op = iterator.initializer
|
||||
|
||||
with self.test_session() as sess:
|
||||
with self.assertRaisesRegexp(errors.FailedPreconditionError,
|
||||
"Failed to capture resource"):
|
||||
sess.run(init_op)
|
||||
|
||||
def testSeededStatefulOperatorIsProperlyStateful(self):
|
||||
iterator = (dataset_ops.Dataset.from_tensors(0).repeat(10)
|
||||
.map(lambda _: random_ops.random_uniform((), seed=11)).batch(2)
|
||||
.make_initializable_iterator())
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.test_session() as sess:
|
||||
sess.run(init_op)
|
||||
random_values = []
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
while True:
|
||||
random_values.extend(sess.run(get_next))
|
||||
self.assertEqual(10, len(random_values))
|
||||
self.assertGreater(np.abs(np.diff(random_values)).max(), 1e-6)
|
||||
sess.run(init_op)
|
||||
random_values_2 = []
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
while True:
|
||||
random_values_2.extend(sess.run(get_next))
|
||||
|
||||
# Randomness is repeatable given same seed
|
||||
self.assertAllClose(random_values, random_values_2)
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
@ -0,0 +1,182 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Test RangeDataset."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.contrib.data.python.ops import dataset_ops
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class RangeDatasetTest(test.TestCase):
|
||||
|
||||
def testStop(self):
|
||||
stop = array_ops.placeholder(dtypes.int64, shape=[])
|
||||
iterator = dataset_ops.Dataset.range(stop).make_initializable_iterator()
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.test_session() as sess:
|
||||
sess.run(init_op, feed_dict={stop: 5})
|
||||
for i in range(5):
|
||||
self.assertEqual(i, sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
def testStartStop(self):
|
||||
start = array_ops.placeholder(dtypes.int64, shape=[])
|
||||
stop = array_ops.placeholder(dtypes.int64, shape=[])
|
||||
iterator = dataset_ops.Dataset.range(start,
|
||||
stop).make_initializable_iterator()
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.test_session() as sess:
|
||||
sess.run(init_op, feed_dict={start: 2, stop: 5})
|
||||
for i in range(2, 5):
|
||||
self.assertEqual(i, sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
def testStartStopStep(self):
|
||||
start = array_ops.placeholder(dtypes.int64, shape=[])
|
||||
stop = array_ops.placeholder(dtypes.int64, shape=[])
|
||||
step = array_ops.placeholder(dtypes.int64, shape=[])
|
||||
iterator = dataset_ops.Dataset.range(start, stop,
|
||||
step).make_initializable_iterator()
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.test_session() as sess:
|
||||
sess.run(init_op, feed_dict={start: 2, stop: 10, step: 2})
|
||||
for i in range(2, 10, 2):
|
||||
self.assertEqual(i, sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
def testZeroStep(self):
|
||||
start = array_ops.placeholder(dtypes.int64, shape=[])
|
||||
stop = array_ops.placeholder(dtypes.int64, shape=[])
|
||||
step = array_ops.placeholder(dtypes.int64, shape=[])
|
||||
iterator = dataset_ops.Dataset.range(start, stop,
|
||||
step).make_initializable_iterator()
|
||||
init_op = iterator.initializer
|
||||
|
||||
with self.test_session() as sess:
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
sess.run(init_op, feed_dict={start: 2, stop: 10, step: 0})
|
||||
|
||||
def testNegativeStep(self):
|
||||
start = array_ops.placeholder(dtypes.int64, shape=[])
|
||||
stop = array_ops.placeholder(dtypes.int64, shape=[])
|
||||
step = array_ops.placeholder(dtypes.int64, shape=[])
|
||||
iterator = dataset_ops.Dataset.range(start, stop,
|
||||
step).make_initializable_iterator()
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.test_session() as sess:
|
||||
sess.run(init_op, feed_dict={start: 2, stop: 10, step: -1})
|
||||
# This for loop is a no-op but will ensure that the implementation is
|
||||
# consistent with range if it ever changes.
|
||||
for i in range(2, 10, -1):
|
||||
self.assertEqual(i, sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
def testStopLessThanStart(self):
|
||||
start = array_ops.placeholder(dtypes.int64, shape=[])
|
||||
stop = array_ops.placeholder(dtypes.int64, shape=[])
|
||||
iterator = dataset_ops.Dataset.range(start,
|
||||
stop).make_initializable_iterator()
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.test_session() as sess:
|
||||
sess.run(init_op, feed_dict={start: 10, stop: 2})
|
||||
# This for loop is a no-op but will ensure that the implementation is
|
||||
# consistent with range if it ever changes.
|
||||
for i in range(10, 2):
|
||||
self.assertEqual(i, sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
def testStopLessThanStartWithPositiveStep(self):
|
||||
start = array_ops.placeholder(dtypes.int64, shape=[])
|
||||
stop = array_ops.placeholder(dtypes.int64, shape=[])
|
||||
step = array_ops.placeholder(dtypes.int64, shape=[])
|
||||
iterator = dataset_ops.Dataset.range(start, stop,
|
||||
step).make_initializable_iterator()
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.test_session() as sess:
|
||||
sess.run(init_op, feed_dict={start: 10, stop: 2, step: 2})
|
||||
# This for loop is a no-op but will ensure that the implementation is
|
||||
# consistent with range if it ever changes.
|
||||
for i in range(10, 2, 2):
|
||||
self.assertEqual(i, sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
def testStopLessThanStartWithNegativeStep(self):
|
||||
start = array_ops.placeholder(dtypes.int64, shape=[])
|
||||
stop = array_ops.placeholder(dtypes.int64, shape=[])
|
||||
step = array_ops.placeholder(dtypes.int64, shape=[])
|
||||
iterator = dataset_ops.Dataset.range(start, stop,
|
||||
step).make_initializable_iterator()
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.test_session() as sess:
|
||||
sess.run(init_op, feed_dict={start: 10, stop: 2, step: -1})
|
||||
for i in range(10, 2, -1):
|
||||
self.assertEqual(i, sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
def testEnumerateDataset(self):
|
||||
components = [np.array(["a", "b"]), np.array([1, 2]), np.array([37.0, 38])]
|
||||
start = constant_op.constant(20, dtype=dtypes.int64)
|
||||
|
||||
iterator = (dataset_ops.Dataset.from_tensor_slices(components).enumerate(
|
||||
start=start).make_initializable_iterator())
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
|
||||
self.assertEqual(dtypes.int64, get_next[0].dtype)
|
||||
self.assertEqual((), get_next[0].shape)
|
||||
self.assertEqual([tensor_shape.TensorShape([])] * 3,
|
||||
[t.shape for t in get_next[1]])
|
||||
|
||||
with self.test_session() as sess:
|
||||
sess.run(init_op)
|
||||
self.assertEqual((20, [b"a", 1, 37.0]), sess.run(get_next))
|
||||
self.assertEqual((21, [b"b", 2, 38.0]), sess.run(get_next))
|
||||
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
@ -0,0 +1,500 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for the experimental input pipeline ops."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import gzip
|
||||
import os
|
||||
import zlib
|
||||
|
||||
from tensorflow.contrib.data.python.ops import dataset_ops
|
||||
from tensorflow.core.example import example_pb2
|
||||
from tensorflow.core.example import feature_pb2
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.lib.io import python_io
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import parsing_ops
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.util import compat
|
||||
|
||||
|
||||
class TextLineDatasetTest(test.TestCase):
|
||||
|
||||
def _lineText(self, f, l):
|
||||
return compat.as_bytes("%d: %d" % (f, l))
|
||||
|
||||
def _createFiles(self, num_files, num_lines, crlf=False):
|
||||
filenames = []
|
||||
for i in range(num_files):
|
||||
fn = os.path.join(self.get_temp_dir(), "text_line.%d.txt" % i)
|
||||
filenames.append(fn)
|
||||
with open(fn, "wb") as f:
|
||||
for j in range(num_lines):
|
||||
f.write(self._lineText(i, j))
|
||||
# Always include a newline after the record unless it is
|
||||
# at the end of the file, in which case we include it sometimes.
|
||||
if j + 1 != num_lines or i == 0:
|
||||
f.write(b"\r\n" if crlf else b"\n")
|
||||
return filenames
|
||||
|
||||
def testTextLineDataset(self):
|
||||
test_filenames = self._createFiles(2, 5, crlf=True)
|
||||
filenames = array_ops.placeholder(dtypes.string, shape=[None])
|
||||
num_epochs = array_ops.placeholder(dtypes.int64, shape=[])
|
||||
batch_size = array_ops.placeholder(dtypes.int64, shape=[])
|
||||
|
||||
repeat_dataset = dataset_ops.TextLineDataset(filenames).repeat(num_epochs)
|
||||
batch_dataset = repeat_dataset.batch(batch_size)
|
||||
|
||||
iterator = dataset_ops.Iterator.from_structure(batch_dataset.output_types)
|
||||
init_op = iterator.make_initializer(repeat_dataset)
|
||||
init_batch_op = iterator.make_initializer(batch_dataset)
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.test_session() as sess:
|
||||
# Basic test: read from file 0.
|
||||
sess.run(init_op, feed_dict={filenames: [test_filenames[0]],
|
||||
num_epochs: 1})
|
||||
for i in range(5):
|
||||
self.assertEqual(self._lineText(0, i), sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
# Basic test: read from file 1.
|
||||
sess.run(init_op, feed_dict={filenames: [test_filenames[1]],
|
||||
num_epochs: 1})
|
||||
for i in range(5):
|
||||
self.assertEqual(self._lineText(1, i), sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
# Basic test: read from both files.
|
||||
sess.run(init_op, feed_dict={filenames: test_filenames,
|
||||
num_epochs: 1})
|
||||
for j in range(2):
|
||||
for i in range(5):
|
||||
self.assertEqual(self._lineText(j, i), sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
# Test repeated iteration through both files.
|
||||
sess.run(init_op, feed_dict={filenames: test_filenames,
|
||||
num_epochs: 10})
|
||||
for _ in range(10):
|
||||
for j in range(2):
|
||||
for i in range(5):
|
||||
self.assertEqual(self._lineText(j, i), sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
# Test batched and repeated iteration through both files.
|
||||
sess.run(init_batch_op, feed_dict={filenames: test_filenames,
|
||||
num_epochs: 10,
|
||||
batch_size: 5})
|
||||
for _ in range(10):
|
||||
self.assertAllEqual([self._lineText(0, i) for i in range(5)],
|
||||
sess.run(get_next))
|
||||
self.assertAllEqual([self._lineText(1, i) for i in range(5)],
|
||||
sess.run(get_next))
|
||||
|
||||
|
||||
class FixedLengthRecordReaderTest(test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super(FixedLengthRecordReaderTest, self).setUp()
|
||||
self._num_files = 2
|
||||
self._num_records = 7
|
||||
self._header_bytes = 5
|
||||
self._record_bytes = 3
|
||||
self._footer_bytes = 2
|
||||
|
||||
def _record(self, f, r):
|
||||
return compat.as_bytes(str(f * 2 + r) * self._record_bytes)
|
||||
|
||||
def _createFiles(self):
|
||||
filenames = []
|
||||
for i in range(self._num_files):
|
||||
fn = os.path.join(self.get_temp_dir(), "fixed_length_record.%d.txt" % i)
|
||||
filenames.append(fn)
|
||||
with open(fn, "wb") as f:
|
||||
f.write(b"H" * self._header_bytes)
|
||||
for j in range(self._num_records):
|
||||
f.write(self._record(i, j))
|
||||
f.write(b"F" * self._footer_bytes)
|
||||
return filenames
|
||||
|
||||
def testFixedLengthRecordDataset(self):
|
||||
test_filenames = self._createFiles()
|
||||
filenames = array_ops.placeholder(dtypes.string, shape=[None])
|
||||
num_epochs = array_ops.placeholder(dtypes.int64, shape=[])
|
||||
batch_size = array_ops.placeholder(dtypes.int64, shape=[])
|
||||
|
||||
repeat_dataset = (dataset_ops.FixedLengthRecordDataset(
|
||||
filenames, self._record_bytes, self._header_bytes, self._footer_bytes)
|
||||
.repeat(num_epochs))
|
||||
batch_dataset = repeat_dataset.batch(batch_size)
|
||||
|
||||
iterator = dataset_ops.Iterator.from_structure(batch_dataset.output_types)
|
||||
init_op = iterator.make_initializer(repeat_dataset)
|
||||
init_batch_op = iterator.make_initializer(batch_dataset)
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.test_session() as sess:
|
||||
# Basic test: read from file 0.
|
||||
sess.run(init_op, feed_dict={filenames: [test_filenames[0]],
|
||||
num_epochs: 1})
|
||||
for i in range(self._num_records):
|
||||
self.assertEqual(self._record(0, i), sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
# Basic test: read from file 1.
|
||||
sess.run(init_op, feed_dict={filenames: [test_filenames[1]],
|
||||
num_epochs: 1})
|
||||
for i in range(self._num_records):
|
||||
self.assertEqual(self._record(1, i), sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
# Basic test: read from both files.
|
||||
sess.run(init_op, feed_dict={filenames: test_filenames,
|
||||
num_epochs: 1})
|
||||
for j in range(self._num_files):
|
||||
for i in range(self._num_records):
|
||||
self.assertEqual(self._record(j, i), sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
# Test repeated iteration through both files.
|
||||
sess.run(init_op, feed_dict={filenames: test_filenames,
|
||||
num_epochs: 10})
|
||||
for _ in range(10):
|
||||
for j in range(self._num_files):
|
||||
for i in range(self._num_records):
|
||||
self.assertEqual(self._record(j, i), sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
# Test batched and repeated iteration through both files.
|
||||
sess.run(init_batch_op, feed_dict={filenames: test_filenames,
|
||||
num_epochs: 10,
|
||||
batch_size: self._num_records})
|
||||
for _ in range(10):
|
||||
for j in range(self._num_files):
|
||||
self.assertAllEqual([self._record(j, i)
|
||||
for i in range(self._num_records)],
|
||||
sess.run(get_next))
|
||||
|
||||
|
||||
class TFRecordDatasetTest(test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super(TFRecordDatasetTest, self).setUp()
|
||||
self._num_files = 2
|
||||
self._num_records = 7
|
||||
|
||||
self.test_filenames = self._createFiles()
|
||||
|
||||
self.filenames = array_ops.placeholder(dtypes.string, shape=[None])
|
||||
self.num_epochs = array_ops.placeholder_with_default(
|
||||
constant_op.constant(1, dtypes.int64), shape=[])
|
||||
self.compression_type = array_ops.placeholder_with_default("", shape=[])
|
||||
self.batch_size = array_ops.placeholder(dtypes.int64, shape=[])
|
||||
|
||||
repeat_dataset = dataset_ops.TFRecordDataset(
|
||||
self.filenames, self.compression_type).repeat(self.num_epochs)
|
||||
batch_dataset = repeat_dataset.batch(self.batch_size)
|
||||
|
||||
iterator = dataset_ops.Iterator.from_structure(batch_dataset.output_types)
|
||||
self.init_op = iterator.make_initializer(repeat_dataset)
|
||||
self.init_batch_op = iterator.make_initializer(batch_dataset)
|
||||
self.get_next = iterator.get_next()
|
||||
|
||||
def _record(self, f, r):
|
||||
return compat.as_bytes("Record %d of file %d" % (r, f))
|
||||
|
||||
def _createFiles(self):
|
||||
filenames = []
|
||||
for i in range(self._num_files):
|
||||
fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i)
|
||||
filenames.append(fn)
|
||||
writer = python_io.TFRecordWriter(fn)
|
||||
for j in range(self._num_records):
|
||||
writer.write(self._record(i, j))
|
||||
writer.close()
|
||||
return filenames
|
||||
|
||||
def testReadOneEpoch(self):
|
||||
with self.test_session() as sess:
|
||||
# Basic test: read from file 0.
|
||||
sess.run(self.init_op,
|
||||
feed_dict={self.filenames: [self.test_filenames[0]],
|
||||
self.num_epochs: 1})
|
||||
for i in range(self._num_records):
|
||||
self.assertAllEqual(self._record(0, i), sess.run(self.get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(self.get_next)
|
||||
|
||||
# Basic test: read from file 1.
|
||||
sess.run(self.init_op,
|
||||
feed_dict={self.filenames: [self.test_filenames[1]],
|
||||
self.num_epochs: 1})
|
||||
for i in range(self._num_records):
|
||||
self.assertAllEqual(self._record(1, i), sess.run(self.get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(self.get_next)
|
||||
|
||||
# Basic test: read from both files.
|
||||
sess.run(self.init_op,
|
||||
feed_dict={self.filenames: self.test_filenames,
|
||||
self.num_epochs: 1})
|
||||
for j in range(self._num_files):
|
||||
for i in range(self._num_records):
|
||||
self.assertAllEqual(self._record(j, i), sess.run(self.get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(self.get_next)
|
||||
|
||||
def testReadTenEpochs(self):
|
||||
with self.test_session() as sess:
|
||||
sess.run(self.init_op, feed_dict={self.filenames: self.test_filenames,
|
||||
self.num_epochs: 10})
|
||||
for _ in range(10):
|
||||
for j in range(self._num_files):
|
||||
for i in range(self._num_records):
|
||||
self.assertAllEqual(self._record(j, i), sess.run(self.get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(self.get_next)
|
||||
|
||||
def testReadTenEpochsOfBatches(self):
|
||||
with self.test_session() as sess:
|
||||
sess.run(self.init_batch_op,
|
||||
feed_dict={self.filenames: self.test_filenames,
|
||||
self.num_epochs: 10,
|
||||
self.batch_size: self._num_records})
|
||||
for _ in range(10):
|
||||
for j in range(self._num_files):
|
||||
values = sess.run(self.get_next)
|
||||
self.assertAllEqual([self._record(j, i)
|
||||
for i in range(self._num_records)], values)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(self.get_next)
|
||||
|
||||
def testReadZlibFiles(self):
|
||||
zlib_files = []
|
||||
for i, fn in enumerate(self.test_filenames):
|
||||
with open(fn, "rb") as f:
|
||||
cdata = zlib.compress(f.read())
|
||||
|
||||
zfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.z" % i)
|
||||
with open(zfn, "wb") as f:
|
||||
f.write(cdata)
|
||||
zlib_files.append(zfn)
|
||||
|
||||
with self.test_session() as sess:
|
||||
sess.run(self.init_op,
|
||||
feed_dict={self.filenames: zlib_files,
|
||||
self.compression_type: "ZLIB"})
|
||||
for j in range(self._num_files):
|
||||
for i in range(self._num_records):
|
||||
self.assertAllEqual(self._record(j, i), sess.run(self.get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(self.get_next)
|
||||
|
||||
def testReadGzipFiles(self):
|
||||
gzip_files = []
|
||||
for i, fn in enumerate(self.test_filenames):
|
||||
with open(fn, "rb") as f:
|
||||
gzfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.gz" % i)
|
||||
with gzip.GzipFile(gzfn, "wb") as gzf:
|
||||
gzf.write(f.read())
|
||||
gzip_files.append(gzfn)
|
||||
|
||||
with self.test_session() as sess:
|
||||
sess.run(self.init_op,
|
||||
feed_dict={self.filenames: gzip_files,
|
||||
self.compression_type: "GZIP"})
|
||||
for j in range(self._num_files):
|
||||
for i in range(self._num_records):
|
||||
self.assertAllEqual(self._record(j, i), sess.run(self.get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(self.get_next)
|
||||
|
||||
|
||||
class ReadBatchFeaturesTest(test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super(ReadBatchFeaturesTest, self).setUp()
|
||||
self._num_files = 2
|
||||
self._num_records = 7
|
||||
self.test_filenames = self._createFiles()
|
||||
|
||||
def _read_batch_features(self, filenames, num_epochs, batch_size):
|
||||
self.filenames = filenames
|
||||
self.num_epochs = num_epochs
|
||||
self.batch_size = batch_size
|
||||
|
||||
return dataset_ops.read_batch_features(
|
||||
file_pattern=self.filenames,
|
||||
batch_size=self.batch_size,
|
||||
features={
|
||||
"file": parsing_ops.FixedLenFeature([], dtypes.int64),
|
||||
"record": parsing_ops.FixedLenFeature([], dtypes.int64),
|
||||
"keywords": parsing_ops.VarLenFeature(dtypes.string)
|
||||
},
|
||||
reader=dataset_ops.TFRecordDataset,
|
||||
randomize_input=False,
|
||||
num_epochs=self.num_epochs)
|
||||
|
||||
def _record(self, f, r):
|
||||
example = example_pb2.Example(features=feature_pb2.Features(
|
||||
feature={
|
||||
"file":
|
||||
feature_pb2.Feature(int64_list=feature_pb2.Int64List(
|
||||
value=[f])),
|
||||
"record":
|
||||
feature_pb2.Feature(int64_list=feature_pb2.Int64List(
|
||||
value=[r])),
|
||||
"keywords":
|
||||
feature_pb2.Feature(bytes_list=feature_pb2.BytesList(
|
||||
value=self._get_keywords(f, r)))
|
||||
}))
|
||||
return example.SerializeToString()
|
||||
|
||||
def _get_keywords(self, f, r):
|
||||
num_keywords = 1 + (f + r) % 2
|
||||
keywords = []
|
||||
for index in range(num_keywords):
|
||||
keywords.append(compat.as_bytes("keyword%d" % index))
|
||||
return keywords
|
||||
|
||||
def _createFiles(self):
|
||||
filenames = []
|
||||
for i in range(self._num_files):
|
||||
fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i)
|
||||
filenames.append(fn)
|
||||
writer = python_io.TFRecordWriter(fn)
|
||||
for j in range(self._num_records):
|
||||
writer.write(self._record(i, j))
|
||||
writer.close()
|
||||
return filenames
|
||||
|
||||
def _next_actual_batch(self, sess):
|
||||
file_op = self.outputs["file"]
|
||||
keywords_indices_op = self.outputs["keywords"].indices
|
||||
keywords_values_op = self.outputs["keywords"].values
|
||||
keywords_dense_shape_op = self.outputs["keywords"].dense_shape
|
||||
record_op = self.outputs["record"]
|
||||
return sess.run([
|
||||
file_op, keywords_indices_op, keywords_values_op,
|
||||
keywords_dense_shape_op, record_op
|
||||
])
|
||||
|
||||
def _next_expected_batch(self, file_indices, batch_size, num_epochs):
|
||||
|
||||
def _next_record(file_indices):
|
||||
for j in file_indices:
|
||||
for i in range(self._num_records):
|
||||
yield j, i
|
||||
|
||||
file_batch = []
|
||||
keywords_batch_indices = []
|
||||
keywords_batch_values = []
|
||||
keywords_batch_max_len = 0
|
||||
record_batch = []
|
||||
batch_index = 0
|
||||
for _ in range(num_epochs):
|
||||
for record in _next_record(file_indices):
|
||||
f = record[0]
|
||||
r = record[1]
|
||||
file_batch.append(f)
|
||||
record_batch.append(r)
|
||||
keywords = self._get_keywords(f, r)
|
||||
keywords_batch_values.extend(keywords)
|
||||
keywords_batch_indices.extend([[batch_index, i]
|
||||
for i in range(len(keywords))])
|
||||
batch_index += 1
|
||||
keywords_batch_max_len = max(keywords_batch_max_len, len(keywords))
|
||||
if len(file_batch) == batch_size:
|
||||
yield [
|
||||
file_batch, keywords_batch_indices, keywords_batch_values,
|
||||
[batch_size, keywords_batch_max_len], record_batch
|
||||
]
|
||||
file_batch = []
|
||||
keywords_batch_indices = []
|
||||
keywords_batch_values = []
|
||||
keywords_batch_max_len = 0
|
||||
record_batch = []
|
||||
batch_index = 0
|
||||
if file_batch:
|
||||
yield [
|
||||
file_batch, keywords_batch_indices, keywords_batch_values,
|
||||
[len(file_batch), keywords_batch_max_len], record_batch
|
||||
]
|
||||
|
||||
def _verify_records(self, sess, batch_size, file_index=None, num_epochs=1):
|
||||
if file_index is not None:
|
||||
file_indices = [file_index]
|
||||
else:
|
||||
file_indices = range(self._num_files)
|
||||
|
||||
for expected_batch in self._next_expected_batch(file_indices, batch_size,
|
||||
num_epochs):
|
||||
actual_batch = self._next_actual_batch(sess)
|
||||
for i in range(len(expected_batch)):
|
||||
self.assertAllEqual(expected_batch[i], actual_batch[i])
|
||||
|
||||
def testRead(self):
|
||||
for batch_size in [1, 2]:
|
||||
for num_epochs in [1, 10]:
|
||||
with ops.Graph().as_default():
|
||||
with self.test_session(graph=ops.get_default_graph()) as sess:
|
||||
# Basic test: read from file 0.
|
||||
self.outputs = self._read_batch_features(
|
||||
filenames=self.test_filenames[0],
|
||||
num_epochs=num_epochs,
|
||||
batch_size=batch_size)
|
||||
self._verify_records(sess, batch_size, 0, num_epochs=num_epochs)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self._next_actual_batch(sess)
|
||||
|
||||
with ops.Graph().as_default():
|
||||
with self.test_session(graph=ops.get_default_graph()) as sess:
|
||||
# Basic test: read from file 1.
|
||||
self.outputs = self._read_batch_features(
|
||||
filenames=self.test_filenames[1],
|
||||
num_epochs=num_epochs,
|
||||
batch_size=batch_size)
|
||||
self._verify_records(sess, batch_size, 1, num_epochs=num_epochs)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self._next_actual_batch(sess)
|
||||
|
||||
with ops.Graph().as_default():
|
||||
with self.test_session(graph=ops.get_default_graph()) as sess:
|
||||
# Basic test: read from both files.
|
||||
self.outputs = self._read_batch_features(
|
||||
filenames=self.test_filenames,
|
||||
num_epochs=num_epochs,
|
||||
batch_size=batch_size)
|
||||
self._verify_records(sess, batch_size, num_epochs=num_epochs)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self._next_actual_batch(sess)
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
79
tensorflow/contrib/data/python/kernel_tests/resample_test.py
Normal file
79
tensorflow/contrib/data/python/kernel_tests/resample_test.py
Normal file
@ -0,0 +1,79 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for the experimental input pipeline ops."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.contrib.data.python.ops import dataset_ops
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.ops import string_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.util import compat
|
||||
|
||||
|
||||
class ResampleTest(test.TestCase):
|
||||
|
||||
def testInitialKnownDistribution(self):
|
||||
self._testDistribution(initial_known=True)
|
||||
|
||||
def testInitialNotKnownDistribution(self):
|
||||
self._testDistribution(initial_known=False)
|
||||
|
||||
def _testDistribution(self, initial_known):
|
||||
classes = np.random.randint(5, size=(20000,)) # Uniformly sampled
|
||||
target_dist = [0.9, 0.05, 0.05, 0.0, 0.0]
|
||||
initial_dist = [0.2] * 5 if initial_known else None
|
||||
iterator = dataset_ops.Iterator.from_dataset(
|
||||
dataset_ops.rejection_resample(
|
||||
(dataset_ops.Dataset.from_tensor_slices(classes)
|
||||
.shuffle(200, seed=21)
|
||||
.map(lambda c: (c, string_ops.as_string(c)))),
|
||||
target_dist=target_dist,
|
||||
initial_dist=initial_dist,
|
||||
class_func=lambda c, _: c,
|
||||
seed=27))
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
variable_init_op = variables.global_variables_initializer()
|
||||
|
||||
with self.test_session() as sess:
|
||||
sess.run(variable_init_op)
|
||||
sess.run(init_op)
|
||||
returned = []
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
while True:
|
||||
returned.append(sess.run(get_next))
|
||||
|
||||
returned_classes, returned_classes_and_data = zip(*returned)
|
||||
_, returned_data = zip(*returned_classes_and_data)
|
||||
self.assertAllEqual([compat.as_bytes(str(c))
|
||||
for c in returned_classes], returned_data)
|
||||
total_returned = len(returned_classes)
|
||||
# Subsampling rejects a large precentage of the initial data in
|
||||
# this case.
|
||||
self.assertGreater(total_returned, 20000 * 0.2)
|
||||
class_counts = np.array([
|
||||
len([True for v in returned_classes if v == c])
|
||||
for c in range(5)])
|
||||
returned_dist = class_counts / total_returned
|
||||
self.assertAllClose(target_dist, returned_dist, atol=1e-2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
@ -0,0 +1,211 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for the experimental input pipeline ops."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.contrib.data.python.ops import dataset_ops
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class SequenceDatasetTest(test.TestCase):
|
||||
|
||||
def testRepeatTensorDataset(self):
|
||||
"""Test a dataset that repeats its input multiple times."""
|
||||
components = [np.array(1), np.array([1, 2, 3]), np.array(37.0)]
|
||||
# This placeholder can be fed when dataset-definition subgraph
|
||||
# runs (i.e. `init_op` below) to configure the number of
|
||||
# repetitions used in a particular iterator.
|
||||
count_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
|
||||
|
||||
iterator = (dataset_ops.Dataset.from_tensors(components)
|
||||
.repeat(count_placeholder).make_initializable_iterator())
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
|
||||
self.assertEqual([c.shape for c in components],
|
||||
[t.shape for t in get_next])
|
||||
|
||||
with self.test_session() as sess:
|
||||
# Test a finite repetition.
|
||||
sess.run(init_op, feed_dict={count_placeholder: 3})
|
||||
for _ in range(3):
|
||||
results = sess.run(get_next)
|
||||
for component, result_component in zip(components, results):
|
||||
self.assertAllEqual(component, result_component)
|
||||
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
# Test a different finite repetition.
|
||||
sess.run(init_op, feed_dict={count_placeholder: 7})
|
||||
for _ in range(7):
|
||||
results = sess.run(get_next)
|
||||
for component, result_component in zip(components, results):
|
||||
self.assertAllEqual(component, result_component)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
# Test an empty repetition.
|
||||
sess.run(init_op, feed_dict={count_placeholder: 0})
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
# Test an infinite repetition.
|
||||
# NOTE(mrry): There's not a good way to test that the sequence
|
||||
# actually is infinite.
|
||||
sess.run(init_op, feed_dict={count_placeholder: -1})
|
||||
for _ in range(17):
|
||||
results = sess.run(get_next)
|
||||
for component, result_component in zip(components, results):
|
||||
self.assertAllEqual(component, result_component)
|
||||
|
||||
def testTakeTensorDataset(self):
|
||||
components = [np.arange(10)]
|
||||
count_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
|
||||
|
||||
iterator = (dataset_ops.Dataset.from_tensor_slices(components)
|
||||
.take(count_placeholder).make_initializable_iterator())
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
|
||||
self.assertEqual([c.shape[1:] for c in components],
|
||||
[t.shape for t in get_next])
|
||||
|
||||
with self.test_session() as sess:
|
||||
# Take fewer than input size
|
||||
sess.run(init_op, feed_dict={count_placeholder: 4})
|
||||
for i in range(4):
|
||||
results = sess.run(get_next)
|
||||
self.assertAllEqual(results, components[0][i:i+1])
|
||||
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
# Take more than input size
|
||||
sess.run(init_op, feed_dict={count_placeholder: 25})
|
||||
for i in range(10):
|
||||
results = sess.run(get_next)
|
||||
self.assertAllEqual(results, components[0][i:i+1])
|
||||
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
# Take all of input
|
||||
sess.run(init_op, feed_dict={count_placeholder: -1})
|
||||
for i in range(10):
|
||||
results = sess.run(get_next)
|
||||
self.assertAllEqual(results, components[0][i:i+1])
|
||||
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
# Take nothing
|
||||
sess.run(init_op, feed_dict={count_placeholder: 0})
|
||||
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
def testSkipTensorDataset(self):
|
||||
components = [np.arange(10)]
|
||||
count_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
|
||||
|
||||
iterator = (dataset_ops.Dataset.from_tensor_slices(components)
|
||||
.skip(count_placeholder).make_initializable_iterator())
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
|
||||
self.assertEqual([c.shape[1:] for c in components],
|
||||
[t.shape for t in get_next])
|
||||
|
||||
with self.test_session() as sess:
|
||||
# Skip fewer than input size, we should skip
|
||||
# the first 4 elements and then read the rest.
|
||||
sess.run(init_op, feed_dict={count_placeholder: 4})
|
||||
for i in range(4, 10):
|
||||
results = sess.run(get_next)
|
||||
self.assertAllEqual(results, components[0][i:i+1])
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
# Skip more than input size: get nothing.
|
||||
sess.run(init_op, feed_dict={count_placeholder: 25})
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
# Skip exactly input size.
|
||||
sess.run(init_op, feed_dict={count_placeholder: 10})
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
# Set -1 for 'count': skip the entire dataset.
|
||||
sess.run(init_op, feed_dict={count_placeholder: -1})
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
# Skip nothing
|
||||
sess.run(init_op, feed_dict={count_placeholder: 0})
|
||||
for i in range(0, 10):
|
||||
results = sess.run(get_next)
|
||||
self.assertAllEqual(results, components[0][i:i+1])
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
def testRepeatRepeatTensorDataset(self):
|
||||
"""Test the composition of repeat datasets."""
|
||||
components = [np.array(1), np.array([1, 2, 3]), np.array(37.0)]
|
||||
inner_count = array_ops.placeholder(dtypes.int64, shape=[])
|
||||
outer_count = array_ops.placeholder(dtypes.int64, shape=[])
|
||||
|
||||
iterator = (dataset_ops.Dataset.from_tensors(components).repeat(inner_count)
|
||||
.repeat(outer_count).make_initializable_iterator())
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
|
||||
self.assertEqual([c.shape for c in components],
|
||||
[t.shape for t in get_next])
|
||||
|
||||
with self.test_session() as sess:
|
||||
sess.run(init_op, feed_dict={inner_count: 7, outer_count: 14})
|
||||
for _ in range(7 * 14):
|
||||
results = sess.run(get_next)
|
||||
for component, result_component in zip(components, results):
|
||||
self.assertAllEqual(component, result_component)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
def testRepeatEmptyDataset(self):
|
||||
"""Test that repeating an empty dataset does not hang."""
|
||||
iterator = (dataset_ops.Dataset.from_tensors(0).repeat(10).skip(10)
|
||||
.repeat(-1).make_initializable_iterator())
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.test_session() as sess:
|
||||
sess.run(init_op)
|
||||
with self.assertRaisesRegexp(
|
||||
errors.OutOfRangeError,
|
||||
"Attempted to repeat an empty dataset infinitely."):
|
||||
sess.run(get_next)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
@ -0,0 +1,152 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for the experimental input pipeline ops."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.contrib.data.python.ops import dataset_ops
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class ShuffleDatasetTest(test.TestCase):
|
||||
|
||||
def testShuffleDataset(self):
|
||||
components = [
|
||||
np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]),
|
||||
np.array([9.0, 10.0, 11.0, 12.0])
|
||||
]
|
||||
count_placeholder = array_ops.placeholder_with_default(
|
||||
constant_op.constant(5, dtypes.int64), shape=[])
|
||||
buffer_size_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
|
||||
seed_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
|
||||
|
||||
repeat_dataset = (dataset_ops.Dataset.from_tensor_slices(components)
|
||||
.repeat(count_placeholder))
|
||||
|
||||
shuffle_dataset = repeat_dataset.shuffle(buffer_size_placeholder,
|
||||
seed_placeholder)
|
||||
|
||||
self.assertEqual([c.shape[1:] for c in components],
|
||||
shuffle_dataset.output_shapes)
|
||||
|
||||
# Create initialization ops for iterators without and with
|
||||
# shuffling, respectively.
|
||||
iterator = dataset_ops.Iterator.from_structure(
|
||||
shuffle_dataset.output_types, shuffle_dataset.output_shapes)
|
||||
init_fifo_op = iterator.make_initializer(repeat_dataset)
|
||||
init_shuffle_op = iterator.make_initializer(shuffle_dataset)
|
||||
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.test_session() as sess:
|
||||
# First run without shuffling to collect the "ground truth".
|
||||
sess.run(init_fifo_op)
|
||||
unshuffled_elements = []
|
||||
for _ in range(20):
|
||||
unshuffled_elements.append(sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
# Assert that the shuffled dataset has the same elements as the
|
||||
# "ground truth".
|
||||
sess.run(
|
||||
init_shuffle_op,
|
||||
feed_dict={buffer_size_placeholder: 100,
|
||||
seed_placeholder: 37})
|
||||
shuffled_elements = []
|
||||
for _ in range(20):
|
||||
shuffled_elements.append(sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
self.assertAllEqual(
|
||||
sorted(unshuffled_elements), sorted(shuffled_elements))
|
||||
|
||||
# Assert that shuffling twice with the same seeds gives the same sequence.
|
||||
sess.run(
|
||||
init_shuffle_op,
|
||||
feed_dict={buffer_size_placeholder: 100,
|
||||
seed_placeholder: 37})
|
||||
reshuffled_elements_same_seed = []
|
||||
for _ in range(20):
|
||||
reshuffled_elements_same_seed.append(sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
self.assertEqual(shuffled_elements, reshuffled_elements_same_seed)
|
||||
|
||||
# Assert that shuffling twice with a different seed gives a different
|
||||
# permutation of the same elements.
|
||||
sess.run(
|
||||
init_shuffle_op,
|
||||
feed_dict={buffer_size_placeholder: 100,
|
||||
seed_placeholder: 1037})
|
||||
reshuffled_elements_different_seed = []
|
||||
for _ in range(20):
|
||||
reshuffled_elements_different_seed.append(sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
self.assertNotEqual(shuffled_elements, reshuffled_elements_different_seed)
|
||||
self.assertAllEqual(
|
||||
sorted(shuffled_elements), sorted(reshuffled_elements_different_seed))
|
||||
|
||||
# Assert that the shuffled dataset has the same elements as the
|
||||
# "ground truth" when the buffer size is smaller than the input
|
||||
# dataset.
|
||||
sess.run(
|
||||
init_shuffle_op,
|
||||
feed_dict={buffer_size_placeholder: 2,
|
||||
seed_placeholder: 37})
|
||||
reshuffled_elements_small_buffer = []
|
||||
for _ in range(20):
|
||||
reshuffled_elements_small_buffer.append(sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
self.assertAllEqual(
|
||||
sorted(unshuffled_elements), sorted(reshuffled_elements_small_buffer))
|
||||
|
||||
# Test the case of shuffling an empty dataset.
|
||||
sess.run(init_shuffle_op, feed_dict={buffer_size_placeholder: 2,
|
||||
seed_placeholder: 37,
|
||||
count_placeholder: 0})
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
def testDefaultArguments(self):
|
||||
components = np.array([0, 1, 2, 3, 4])
|
||||
iterator = (dataset_ops.Dataset.from_tensor_slices(components).shuffle(5)
|
||||
.repeat().make_one_shot_iterator())
|
||||
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.test_session() as sess:
|
||||
counts = collections.defaultdict(lambda: 0)
|
||||
for _ in range(10):
|
||||
for _ in range(5):
|
||||
counts[sess.run(get_next)] += 1
|
||||
|
||||
for i in range(5):
|
||||
self.assertEqual(10, counts[i])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
@ -0,0 +1,114 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for the experimental input pipeline ops."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.contrib.data.python.ops import dataset_ops
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class ZipDatasetTest(test.TestCase):
|
||||
|
||||
def testZipDataset(self):
|
||||
component_placeholders = [
|
||||
array_ops.placeholder(dtypes.int64),
|
||||
array_ops.placeholder(dtypes.int64),
|
||||
array_ops.placeholder(dtypes.float64)
|
||||
]
|
||||
|
||||
datasets = [
|
||||
dataset_ops.Dataset.from_tensor_slices(component_placeholder)
|
||||
for component_placeholder in component_placeholders
|
||||
]
|
||||
zipped = dataset_ops.Dataset.zip(datasets)
|
||||
|
||||
iterator = zipped.make_initializable_iterator()
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.test_session() as sess:
|
||||
equal_length_components = [
|
||||
np.tile(np.array([[1], [2], [3], [4]]), 20),
|
||||
np.tile(np.array([[12], [13], [14], [15]]), 22),
|
||||
np.array([37.0, 38.0, 39.0, 40.0])
|
||||
]
|
||||
sess.run(init_op, feed_dict={ph: value for ph, value in zip(
|
||||
component_placeholders, equal_length_components)})
|
||||
for i in range(4):
|
||||
results = sess.run(get_next)
|
||||
for component, result_component in zip(
|
||||
equal_length_components, results):
|
||||
self.assertAllEqual(component[i], result_component)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
variable_length_components = [[1, 2, 3, 4], [1, 2, 3, 4, 5], [1.0, 2.0]]
|
||||
sess.run(init_op, feed_dict={ph: value for ph, value in zip(
|
||||
component_placeholders, variable_length_components)})
|
||||
for i in range(2):
|
||||
results = sess.run(get_next)
|
||||
for component, result_component in zip(
|
||||
variable_length_components, results):
|
||||
self.assertAllEqual(component[i], result_component)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
def testNestedZipDataset(self):
|
||||
component_placeholders = [
|
||||
array_ops.placeholder(dtypes.int64, shape=[4, 20]),
|
||||
array_ops.placeholder(dtypes.int64, shape=[4, 22]),
|
||||
array_ops.placeholder(dtypes.float64, shape=[4])
|
||||
]
|
||||
|
||||
datasets = [
|
||||
dataset_ops.Dataset.from_tensor_slices(component_placeholder)
|
||||
for component_placeholder in component_placeholders
|
||||
]
|
||||
zipped = dataset_ops.Dataset.zip((datasets[0], (datasets[1], datasets[2])))
|
||||
|
||||
iterator = zipped.make_initializable_iterator()
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
|
||||
self.assertEqual([20], get_next[0].shape)
|
||||
self.assertEqual([22], get_next[1][0].shape)
|
||||
self.assertEqual([], get_next[1][1].shape)
|
||||
|
||||
with self.test_session() as sess:
|
||||
equal_length_components = [
|
||||
np.tile(np.array([[1], [2], [3], [4]]), 20),
|
||||
np.tile(np.array([[12], [13], [14], [15]]), 22),
|
||||
np.array([37.0, 38.0, 39.0, 40.0])
|
||||
]
|
||||
sess.run(init_op, feed_dict={ph: value for ph, value in zip(
|
||||
component_placeholders, equal_length_components)})
|
||||
for i in range(4):
|
||||
result1, (result2, result3) = sess.run(get_next)
|
||||
self.assertAllEqual(equal_length_components[0][i], result1)
|
||||
self.assertAllEqual(equal_length_components[1][i], result2)
|
||||
self.assertAllEqual(equal_length_components[2][i], result3)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
31
tensorflow/contrib/data/python/ops/BUILD
Normal file
31
tensorflow/contrib/data/python/ops/BUILD
Normal file
@ -0,0 +1,31 @@
|
||||
package(default_visibility = ["//tensorflow:internal"])
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
py_library(
|
||||
name = "dataset_ops",
|
||||
srcs = ["dataset_ops.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/contrib/data/python/framework:function",
|
||||
"//tensorflow/contrib/util:util_py",
|
||||
"//tensorflow/python:dataset_ops_gen",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:parsing_ops",
|
||||
"//tensorflow/python:util",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(
|
||||
["**/*"],
|
||||
exclude = [
|
||||
"**/METADATA",
|
||||
"**/OWNERS",
|
||||
],
|
||||
),
|
||||
visibility = ["//tensorflow:__subpackages__"],
|
||||
)
|
1902
tensorflow/contrib/data/python/ops/dataset_ops.py
Normal file
1902
tensorflow/contrib/data/python/ops/dataset_ops.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -273,7 +273,6 @@ class Layer(tf_base_layers.Layer):
|
||||
# Internal methods:
|
||||
build(input_shape)
|
||||
_add_inbound_node(layer, index=0)
|
||||
assert_input_compatibility()
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
@ -381,97 +380,6 @@ class Layer(tf_base_layers.Layer):
|
||||
self.constraints[weight] = constraint
|
||||
return weight
|
||||
|
||||
def assert_input_compatibility(self, inputs):
|
||||
"""Checks compatibility between the layer and provided inputs.
|
||||
|
||||
This checks that the tensor(s) `input`
|
||||
verify the input assumptions of the layer
|
||||
(if any). If not, exceptions are raised.
|
||||
|
||||
Arguments:
|
||||
inputs: input tensor or list of input tensors.
|
||||
|
||||
Raises:
|
||||
ValueError: in case of mismatch between
|
||||
the provided inputs and the expectations of the layer.
|
||||
"""
|
||||
if not self.input_spec:
|
||||
return
|
||||
if not isinstance(self.input_spec, (list, tuple)):
|
||||
input_spec = _to_list(self.input_spec)
|
||||
else:
|
||||
input_spec = self.input_spec
|
||||
inputs = _to_list(inputs)
|
||||
if len(inputs) != len(input_spec):
|
||||
raise ValueError('Layer ' + self.name + ' expects ' +
|
||||
str(len(input_spec)) + ' inputs, '
|
||||
'but it received ' + str(len(inputs)) +
|
||||
' input tensors. Input received: ' + str(inputs))
|
||||
for input_index, (x, spec) in enumerate(zip(inputs, input_spec)):
|
||||
if spec is None:
|
||||
continue
|
||||
|
||||
# Check ndim.
|
||||
if spec.ndim is not None:
|
||||
if K.ndim(x) != spec.ndim:
|
||||
raise ValueError('Input ' + str(input_index) +
|
||||
' is incompatible with layer ' + self.name +
|
||||
': expected ndim=' + str(spec.ndim) + ', found ndim='
|
||||
+ str(K.ndim(x)))
|
||||
if spec.max_ndim is not None:
|
||||
ndim = K.ndim(x)
|
||||
if ndim is not None and ndim > spec.max_ndim:
|
||||
raise ValueError('Input ' + str(input_index) +
|
||||
' is incompatible with layer ' + self.name +
|
||||
': expected max_ndim=' + str(spec.max_ndim) +
|
||||
', found ndim=' + str(K.ndim(x)))
|
||||
if spec.min_ndim is not None:
|
||||
ndim = K.ndim(x)
|
||||
if ndim is not None and ndim < spec.min_ndim:
|
||||
raise ValueError('Input ' + str(input_index) +
|
||||
' is incompatible with layer ' + self.name +
|
||||
': expected min_ndim=' + str(spec.min_ndim) +
|
||||
', found ndim=' + str(K.ndim(x)))
|
||||
# Check dtype.
|
||||
if spec.dtype is not None:
|
||||
if K.dtype(x) != spec.dtype:
|
||||
raise ValueError('Input ' + str(input_index) +
|
||||
' is incompatible with layer ' + self.name +
|
||||
': expected dtype=' + str(spec.dtype) +
|
||||
', found dtype=' + str(K.dtype(x)))
|
||||
# Check specific shape axes.
|
||||
if spec.axes:
|
||||
try:
|
||||
x_shape = K.int_shape(x)
|
||||
except TypeError:
|
||||
x_shape = None
|
||||
if x_shape is not None:
|
||||
for axis, value in spec.axes.items():
|
||||
if hasattr(value, 'value'):
|
||||
value = value.value
|
||||
if value is not None and x_shape[int(axis)] not in {value, None}:
|
||||
raise ValueError(
|
||||
'Input ' + str(input_index) + ' is incompatible with layer ' +
|
||||
self.name + ': expected axis ' + str(axis) +
|
||||
' of input shape to have '
|
||||
'value ' + str(value) + ' but got shape ' + str(x_shape))
|
||||
# Check shape.
|
||||
if spec.shape is not None:
|
||||
try:
|
||||
x_shape = K.int_shape(x)
|
||||
except TypeError:
|
||||
x_shape = None
|
||||
if x_shape is not None:
|
||||
for spec_dim, dim in zip(spec.shape, x_shape):
|
||||
if hasattr(spec_dim, 'value'):
|
||||
spec_dim = spec_dim.value
|
||||
if spec_dim is not None and dim is not None:
|
||||
if spec_dim != dim:
|
||||
raise ValueError('Input ' + str(input_index) +
|
||||
' is incompatible with layer ' + self.name +
|
||||
': expected shape=' + str(spec.shape) +
|
||||
', found shape=' + str(x_shape))
|
||||
|
||||
def call(self, inputs, **kwargs): # pylint: disable=unused-argument
|
||||
"""This is where the layer's logic lives.
|
||||
|
||||
@ -509,11 +417,6 @@ class Layer(tf_base_layers.Layer):
|
||||
if isinstance(inputs, list):
|
||||
inputs = inputs[:]
|
||||
|
||||
# Raise exceptions in case the input is not compatible
|
||||
# with the input_spec set at build time.
|
||||
# TODO(fchollet): call after the layer is built, too.
|
||||
self.assert_input_compatibility(inputs)
|
||||
|
||||
# Handle mask propagation.
|
||||
previous_mask = _collect_previous_mask(inputs)
|
||||
user_kwargs = copy.copy(kwargs)
|
||||
|
@ -37,233 +37,11 @@ from tensorflow.contrib.keras.python.keras.layers.pooling import MaxPooling3D
|
||||
# pylint: enable=unused-import
|
||||
from tensorflow.contrib.keras.python.keras.utils import conv_utils
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.layers import convolutional as tf_convolutional_layers
|
||||
|
||||
|
||||
class _Conv(Layer):
|
||||
"""Abstract nD convolution layer (private, used as implementation base).
|
||||
|
||||
This layer creates a convolution kernel that is convolved
|
||||
with the layer input to produce a tensor of outputs.
|
||||
If `use_bias` is True, a bias vector is created and added to the outputs.
|
||||
Finally, if `activation` is not `None`,
|
||||
it is applied to the outputs as well.
|
||||
|
||||
Arguments:
|
||||
rank: An integer, the rank of the convolution,
|
||||
e.g. "2" for 2D convolution.
|
||||
filters: Integer, the dimensionality of the output space
|
||||
(i.e. the number output of filters in the convolution).
|
||||
kernel_size: An integer or tuple/list of n integers, specifying the
|
||||
dimensions of the convolution window.
|
||||
strides: An integer or tuple/list of n integers,
|
||||
specifying the strides of the convolution.
|
||||
Specifying any stride value != 1 is incompatible with specifying
|
||||
any `dilation_rate` value != 1.
|
||||
padding: One of `"valid"` or `"same"` (case-insensitive).
|
||||
data_format: A string,
|
||||
one of `channels_last` (default) or `channels_first`.
|
||||
The ordering of the dimensions in the inputs.
|
||||
`channels_last` corresponds to inputs with shape
|
||||
`(batch, ..., channels)` while `channels_first` corresponds to
|
||||
inputs with shape `(batch, channels, ...)`.
|
||||
It defaults to the `image_data_format` value found in your
|
||||
Keras config file at `~/.keras/keras.json`.
|
||||
If you never set it, then it will be "channels_last".
|
||||
dilation_rate: An integer or tuple/list of n integers, specifying
|
||||
the dilation rate to use for dilated convolution.
|
||||
Currently, specifying any `dilation_rate` value != 1 is
|
||||
incompatible with specifying any `strides` value != 1.
|
||||
activation: Activation function to use.
|
||||
If you don't specify anything, no activation is applied
|
||||
(ie. "linear" activation: `a(x) = x`).
|
||||
use_bias: Boolean, whether the layer uses a bias vector.
|
||||
kernel_initializer: Initializer for the `kernel` weights matrix.
|
||||
bias_initializer: Initializer for the bias vector.
|
||||
kernel_regularizer: Regularizer function applied to
|
||||
the `kernel` weights matrix.
|
||||
bias_regularizer: Regularizer function applied to the bias vector.
|
||||
activity_regularizer: Regularizer function applied to
|
||||
the output of the layer (its "activation")..
|
||||
kernel_constraint: Constraint function applied to the kernel matrix.
|
||||
bias_constraint: Constraint function applied to the bias vector.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
rank,
|
||||
filters,
|
||||
kernel_size,
|
||||
strides=1,
|
||||
padding='valid',
|
||||
data_format=None,
|
||||
dilation_rate=1,
|
||||
activation=None,
|
||||
use_bias=True,
|
||||
kernel_initializer='glorot_uniform',
|
||||
bias_initializer='zeros',
|
||||
kernel_regularizer=None,
|
||||
bias_regularizer=None,
|
||||
activity_regularizer=None,
|
||||
kernel_constraint=None,
|
||||
bias_constraint=None,
|
||||
**kwargs):
|
||||
super(_Conv, self).__init__(**kwargs)
|
||||
self.rank = rank
|
||||
self.filters = filters
|
||||
self.kernel_size = conv_utils.normalize_tuple(kernel_size, rank,
|
||||
'kernel_size')
|
||||
self.strides = conv_utils.normalize_tuple(strides, rank, 'strides')
|
||||
self.padding = conv_utils.normalize_padding(padding)
|
||||
self.data_format = conv_utils.normalize_data_format(data_format)
|
||||
self.dilation_rate = conv_utils.normalize_tuple(dilation_rate, rank,
|
||||
'dilation_rate')
|
||||
self.activation = activations.get(activation)
|
||||
self.use_bias = use_bias
|
||||
self.kernel_initializer = initializers.get(kernel_initializer)
|
||||
self.bias_initializer = initializers.get(bias_initializer)
|
||||
self.kernel_regularizer = regularizers.get(kernel_regularizer)
|
||||
self.bias_regularizer = regularizers.get(bias_regularizer)
|
||||
self.activity_regularizer = regularizers.get(activity_regularizer)
|
||||
self.kernel_constraint = constraints.get(kernel_constraint)
|
||||
self.bias_constraint = constraints.get(bias_constraint)
|
||||
self.input_spec = InputSpec(ndim=self.rank + 2)
|
||||
|
||||
def build(self, input_shape):
|
||||
input_shape = tensor_shape.TensorShape(input_shape).as_list()
|
||||
if self.data_format == 'channels_first':
|
||||
channel_axis = 1
|
||||
else:
|
||||
channel_axis = -1
|
||||
if input_shape[channel_axis] is None:
|
||||
raise ValueError('The channel dimension of the inputs '
|
||||
'should be defined. Found `None`.')
|
||||
input_dim = input_shape[channel_axis]
|
||||
kernel_shape = self.kernel_size + (input_dim, self.filters)
|
||||
|
||||
self.kernel = self.add_weight(
|
||||
shape=kernel_shape,
|
||||
initializer=self.kernel_initializer,
|
||||
name='kernel',
|
||||
regularizer=self.kernel_regularizer,
|
||||
constraint=self.kernel_constraint)
|
||||
if self.use_bias:
|
||||
self.bias = self.add_weight(
|
||||
shape=(self.filters,),
|
||||
initializer=self.bias_initializer,
|
||||
name='bias',
|
||||
regularizer=self.bias_regularizer,
|
||||
constraint=self.bias_constraint)
|
||||
else:
|
||||
self.bias = None
|
||||
# Set input spec.
|
||||
self.input_spec = InputSpec(
|
||||
ndim=self.rank + 2, axes={channel_axis: input_dim})
|
||||
self.built = True
|
||||
|
||||
def call(self, inputs):
|
||||
if self.rank == 1:
|
||||
outputs = K.conv1d(
|
||||
inputs,
|
||||
self.kernel,
|
||||
strides=self.strides[0],
|
||||
padding=self.padding,
|
||||
data_format=self.data_format,
|
||||
dilation_rate=self.dilation_rate[0])
|
||||
if self.rank == 2:
|
||||
outputs = K.conv2d(
|
||||
inputs,
|
||||
self.kernel,
|
||||
strides=self.strides,
|
||||
padding=self.padding,
|
||||
data_format=self.data_format,
|
||||
dilation_rate=self.dilation_rate)
|
||||
if self.rank == 3:
|
||||
outputs = K.conv3d(
|
||||
inputs,
|
||||
self.kernel,
|
||||
strides=self.strides,
|
||||
padding=self.padding,
|
||||
data_format=self.data_format,
|
||||
dilation_rate=self.dilation_rate)
|
||||
|
||||
if self.use_bias:
|
||||
outputs = K.bias_add(outputs, self.bias, data_format=self.data_format)
|
||||
|
||||
if self.activation is not None:
|
||||
return self.activation(outputs)
|
||||
return outputs
|
||||
|
||||
def _compute_output_shape(self, input_shape):
|
||||
input_shape = tensor_shape.TensorShape(input_shape).as_list()
|
||||
if self.data_format == 'channels_last':
|
||||
space = input_shape[1:-1]
|
||||
new_space = []
|
||||
for i in range(len(space)):
|
||||
new_dim = conv_utils.conv_output_length(
|
||||
space[i],
|
||||
self.kernel_size[i],
|
||||
padding=self.padding,
|
||||
stride=self.strides[i],
|
||||
dilation=self.dilation_rate[i])
|
||||
new_space.append(new_dim)
|
||||
return tensor_shape.TensorShape([input_shape[0]] + new_space +
|
||||
[self.filters])
|
||||
else:
|
||||
space = input_shape[2:]
|
||||
new_space = []
|
||||
for i in range(len(space)):
|
||||
new_dim = conv_utils.conv_output_length(
|
||||
space[i],
|
||||
self.kernel_size[i],
|
||||
padding=self.padding,
|
||||
stride=self.strides[i],
|
||||
dilation=self.dilation_rate[i])
|
||||
new_space.append(new_dim)
|
||||
return tensor_shape.TensorShape([input_shape[0], self.filters] +
|
||||
new_space)
|
||||
|
||||
def get_config(self):
|
||||
config = {
|
||||
'rank':
|
||||
self.rank,
|
||||
'filters':
|
||||
self.filters,
|
||||
'kernel_size':
|
||||
self.kernel_size,
|
||||
'strides':
|
||||
self.strides,
|
||||
'padding':
|
||||
self.padding,
|
||||
'data_format':
|
||||
self.data_format,
|
||||
'dilation_rate':
|
||||
self.dilation_rate,
|
||||
'activation':
|
||||
activations.serialize(self.activation),
|
||||
'use_bias':
|
||||
self.use_bias,
|
||||
'kernel_initializer':
|
||||
initializers.serialize(self.kernel_initializer),
|
||||
'bias_initializer':
|
||||
initializers.serialize(self.bias_initializer),
|
||||
'kernel_regularizer':
|
||||
regularizers.serialize(self.kernel_regularizer),
|
||||
'bias_regularizer':
|
||||
regularizers.serialize(self.bias_regularizer),
|
||||
'activity_regularizer':
|
||||
regularizers.serialize(self.activity_regularizer),
|
||||
'kernel_constraint':
|
||||
constraints.serialize(self.kernel_constraint),
|
||||
'bias_constraint':
|
||||
constraints.serialize(self.bias_constraint)
|
||||
}
|
||||
base_config = super(_Conv, self).get_config()
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
||||
|
||||
|
||||
class Conv1D(_Conv):
|
||||
"""1D convolution layer (e.g.
|
||||
|
||||
temporal convolution).
|
||||
class Conv1D(tf_convolutional_layers.Conv1D, Layer):
|
||||
"""1D convolution layer (e.g. temporal convolution).
|
||||
|
||||
This layer creates a convolution kernel that is convolved
|
||||
with the layer input over a single spatial (or temporal) dimension
|
||||
@ -336,33 +114,55 @@ class Conv1D(_Conv):
|
||||
bias_constraint=None,
|
||||
**kwargs):
|
||||
super(Conv1D, self).__init__(
|
||||
rank=1,
|
||||
filters=filters,
|
||||
kernel_size=kernel_size,
|
||||
strides=strides,
|
||||
padding=padding,
|
||||
data_format='channels_last',
|
||||
dilation_rate=dilation_rate,
|
||||
activation=activation,
|
||||
activation=activations.get(activation),
|
||||
use_bias=use_bias,
|
||||
kernel_initializer=kernel_initializer,
|
||||
bias_initializer=bias_initializer,
|
||||
kernel_regularizer=kernel_regularizer,
|
||||
bias_regularizer=bias_regularizer,
|
||||
activity_regularizer=activity_regularizer,
|
||||
kernel_constraint=kernel_constraint,
|
||||
bias_constraint=bias_constraint,
|
||||
kernel_initializer=initializers.get(kernel_initializer),
|
||||
bias_initializer=initializers.get(bias_initializer),
|
||||
kernel_regularizer=regularizers.get(kernel_regularizer),
|
||||
bias_regularizer=regularizers.get(bias_regularizer),
|
||||
activity_regularizer=regularizers.get(activity_regularizer),
|
||||
**kwargs)
|
||||
self.input_spec = InputSpec(ndim=3)
|
||||
# TODO(fchollet): move weight constraint support to core layers.
|
||||
self.kernel_constraint = constraints.get(kernel_constraint)
|
||||
self.bias_constraint = constraints.get(bias_constraint)
|
||||
|
||||
def build(self, input_shape):
|
||||
super(Conv1D, self).build(input_shape)
|
||||
# TODO(fchollet): move weight constraint support to core layers.
|
||||
if self.kernel_constraint:
|
||||
self.constraints[self.kernel] = self.kernel_constraint
|
||||
if self.use_bias and self.bias_constraint:
|
||||
self.constraints[self.bias] = self.bias_constraint
|
||||
|
||||
def get_config(self):
|
||||
config = super(Conv1D, self).get_config()
|
||||
config.pop('rank')
|
||||
config.pop('data_format')
|
||||
return config
|
||||
config = {
|
||||
'filters': self.filters,
|
||||
'kernel_size': self.kernel_size,
|
||||
'strides': self.strides,
|
||||
'padding': self.padding,
|
||||
'dilation_rate': self.dilation_rate,
|
||||
'activation': activations.serialize(self.activation),
|
||||
'use_bias': self.use_bias,
|
||||
'kernel_initializer': initializers.serialize(self.kernel_initializer),
|
||||
'bias_initializer': initializers.serialize(self.bias_initializer),
|
||||
'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
|
||||
'bias_regularizer': regularizers.serialize(self.bias_regularizer),
|
||||
'activity_regularizer':
|
||||
regularizers.serialize(self.activity_regularizer),
|
||||
'kernel_constraint': constraints.serialize(self.kernel_constraint),
|
||||
'bias_constraint': constraints.serialize(self.bias_constraint)
|
||||
}
|
||||
base_config = super(Conv1D, self).get_config()
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
||||
|
||||
|
||||
class Conv2D(_Conv):
|
||||
class Conv2D(tf_convolutional_layers.Conv2D, Layer):
|
||||
"""2D convolution layer (e.g. spatial convolution over images).
|
||||
|
||||
This layer creates a convolution kernel that is convolved
|
||||
@ -452,36 +252,60 @@ class Conv2D(_Conv):
|
||||
kernel_constraint=None,
|
||||
bias_constraint=None,
|
||||
**kwargs):
|
||||
if data_format is None:
|
||||
data_format = K.image_data_format()
|
||||
super(Conv2D, self).__init__(
|
||||
rank=2,
|
||||
filters=filters,
|
||||
kernel_size=kernel_size,
|
||||
strides=strides,
|
||||
padding=padding,
|
||||
data_format=data_format,
|
||||
dilation_rate=dilation_rate,
|
||||
activation=activation,
|
||||
activation=activations.get(activation),
|
||||
use_bias=use_bias,
|
||||
kernel_initializer=kernel_initializer,
|
||||
bias_initializer=bias_initializer,
|
||||
kernel_regularizer=kernel_regularizer,
|
||||
bias_regularizer=bias_regularizer,
|
||||
activity_regularizer=activity_regularizer,
|
||||
kernel_constraint=kernel_constraint,
|
||||
bias_constraint=bias_constraint,
|
||||
kernel_initializer=initializers.get(kernel_initializer),
|
||||
bias_initializer=initializers.get(bias_initializer),
|
||||
kernel_regularizer=regularizers.get(kernel_regularizer),
|
||||
bias_regularizer=regularizers.get(bias_regularizer),
|
||||
activity_regularizer=regularizers.get(activity_regularizer),
|
||||
**kwargs)
|
||||
self.input_spec = InputSpec(ndim=4)
|
||||
# TODO(fchollet): move weight constraint support to core layers.
|
||||
self.kernel_constraint = constraints.get(kernel_constraint)
|
||||
self.bias_constraint = constraints.get(bias_constraint)
|
||||
|
||||
def build(self, input_shape):
|
||||
super(Conv2D, self).build(input_shape)
|
||||
# TODO(fchollet): move weight constraint support to core layers.
|
||||
if self.kernel_constraint:
|
||||
self.constraints[self.kernel] = self.kernel_constraint
|
||||
if self.use_bias and self.bias_constraint:
|
||||
self.constraints[self.bias] = self.bias_constraint
|
||||
|
||||
def get_config(self):
|
||||
config = super(Conv2D, self).get_config()
|
||||
config.pop('rank')
|
||||
return config
|
||||
config = {
|
||||
'filters': self.filters,
|
||||
'kernel_size': self.kernel_size,
|
||||
'strides': self.strides,
|
||||
'padding': self.padding,
|
||||
'data_format': self.data_format,
|
||||
'dilation_rate': self.dilation_rate,
|
||||
'activation': activations.serialize(self.activation),
|
||||
'use_bias': self.use_bias,
|
||||
'kernel_initializer': initializers.serialize(self.kernel_initializer),
|
||||
'bias_initializer': initializers.serialize(self.bias_initializer),
|
||||
'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
|
||||
'bias_regularizer': regularizers.serialize(self.bias_regularizer),
|
||||
'activity_regularizer':
|
||||
regularizers.serialize(self.activity_regularizer),
|
||||
'kernel_constraint': constraints.serialize(self.kernel_constraint),
|
||||
'bias_constraint': constraints.serialize(self.bias_constraint)
|
||||
}
|
||||
base_config = super(Conv2D, self).get_config()
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
||||
|
||||
|
||||
class Conv3D(_Conv):
|
||||
"""3D convolution layer (e.g.
|
||||
|
||||
spatial convolution over volumes).
|
||||
class Conv3D(tf_convolutional_layers.Conv3D, Layer):
|
||||
"""3D convolution layer (e.g. spatial convolution over volumes).
|
||||
|
||||
This layer creates a convolution kernel that is convolved
|
||||
with the layer input to produce a tensor of
|
||||
@ -577,33 +401,59 @@ class Conv3D(_Conv):
|
||||
kernel_constraint=None,
|
||||
bias_constraint=None,
|
||||
**kwargs):
|
||||
if data_format is None:
|
||||
data_format = K.image_data_format()
|
||||
super(Conv3D, self).__init__(
|
||||
rank=3,
|
||||
filters=filters,
|
||||
kernel_size=kernel_size,
|
||||
strides=strides,
|
||||
padding=padding,
|
||||
data_format=data_format,
|
||||
dilation_rate=dilation_rate,
|
||||
activation=activation,
|
||||
activation=activations.get(activation),
|
||||
use_bias=use_bias,
|
||||
kernel_initializer=kernel_initializer,
|
||||
bias_initializer=bias_initializer,
|
||||
kernel_regularizer=kernel_regularizer,
|
||||
bias_regularizer=bias_regularizer,
|
||||
activity_regularizer=activity_regularizer,
|
||||
kernel_constraint=kernel_constraint,
|
||||
bias_constraint=bias_constraint,
|
||||
kernel_initializer=initializers.get(kernel_initializer),
|
||||
bias_initializer=initializers.get(bias_initializer),
|
||||
kernel_regularizer=regularizers.get(kernel_regularizer),
|
||||
bias_regularizer=regularizers.get(bias_regularizer),
|
||||
activity_regularizer=regularizers.get(activity_regularizer),
|
||||
**kwargs)
|
||||
self.input_spec = InputSpec(ndim=5)
|
||||
# TODO(fchollet): move weight constraint support to core layers.
|
||||
self.kernel_constraint = constraints.get(kernel_constraint)
|
||||
self.bias_constraint = constraints.get(bias_constraint)
|
||||
|
||||
def build(self, input_shape):
|
||||
super(Conv3D, self).build(input_shape)
|
||||
# TODO(fchollet): move weight constraint support to core layers.
|
||||
if self.kernel_constraint:
|
||||
self.constraints[self.kernel] = self.kernel_constraint
|
||||
if self.use_bias and self.bias_constraint:
|
||||
self.constraints[self.bias] = self.bias_constraint
|
||||
|
||||
def get_config(self):
|
||||
config = super(Conv3D, self).get_config()
|
||||
config.pop('rank')
|
||||
return config
|
||||
config = {
|
||||
'filters': self.filters,
|
||||
'kernel_size': self.kernel_size,
|
||||
'strides': self.strides,
|
||||
'padding': self.padding,
|
||||
'data_format': self.data_format,
|
||||
'dilation_rate': self.dilation_rate,
|
||||
'activation': activations.serialize(self.activation),
|
||||
'use_bias': self.use_bias,
|
||||
'kernel_initializer': initializers.serialize(self.kernel_initializer),
|
||||
'bias_initializer': initializers.serialize(self.bias_initializer),
|
||||
'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
|
||||
'bias_regularizer': regularizers.serialize(self.bias_regularizer),
|
||||
'activity_regularizer':
|
||||
regularizers.serialize(self.activity_regularizer),
|
||||
'kernel_constraint': constraints.serialize(self.kernel_constraint),
|
||||
'bias_constraint': constraints.serialize(self.bias_constraint)
|
||||
}
|
||||
base_config = super(Conv3D, self).get_config()
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
||||
|
||||
|
||||
class Conv2DTranspose(Conv2D):
|
||||
class Conv2DTranspose(tf_convolutional_layers.Conv2DTranspose, Layer):
|
||||
"""Transposed convolution layer (sometimes called Deconvolution).
|
||||
|
||||
The need for transposed convolutions generally arises
|
||||
@ -699,121 +549,57 @@ class Conv2DTranspose(Conv2D):
|
||||
kernel_constraint=None,
|
||||
bias_constraint=None,
|
||||
**kwargs):
|
||||
if data_format is None:
|
||||
data_format = K.image_data_format()
|
||||
super(Conv2DTranspose, self).__init__(
|
||||
filters,
|
||||
kernel_size,
|
||||
filters=filters,
|
||||
kernel_size=kernel_size,
|
||||
strides=strides,
|
||||
padding=padding,
|
||||
data_format=data_format,
|
||||
activation=activation,
|
||||
activation=activations.get(activation),
|
||||
use_bias=use_bias,
|
||||
kernel_initializer=kernel_initializer,
|
||||
bias_initializer=bias_initializer,
|
||||
kernel_regularizer=kernel_regularizer,
|
||||
bias_regularizer=bias_regularizer,
|
||||
activity_regularizer=activity_regularizer,
|
||||
kernel_constraint=kernel_constraint,
|
||||
bias_constraint=bias_constraint,
|
||||
kernel_initializer=initializers.get(kernel_initializer),
|
||||
bias_initializer=initializers.get(bias_initializer),
|
||||
kernel_regularizer=regularizers.get(kernel_regularizer),
|
||||
bias_regularizer=regularizers.get(bias_regularizer),
|
||||
activity_regularizer=regularizers.get(activity_regularizer),
|
||||
**kwargs)
|
||||
self.input_spec = InputSpec(ndim=4)
|
||||
# TODO(fchollet): move weight constraint support to core layers.
|
||||
self.kernel_constraint = constraints.get(kernel_constraint)
|
||||
self.bias_constraint = constraints.get(bias_constraint)
|
||||
|
||||
def build(self, input_shape):
|
||||
input_shape = tensor_shape.TensorShape(input_shape).as_list()
|
||||
if len(input_shape) != 4:
|
||||
raise ValueError(
|
||||
'Inputs should have rank ' + str(4) + '; Received input shape:',
|
||||
str(input_shape))
|
||||
if self.data_format == 'channels_first':
|
||||
channel_axis = 1
|
||||
else:
|
||||
channel_axis = -1
|
||||
if input_shape[channel_axis] is None:
|
||||
raise ValueError('The channel dimension of the inputs '
|
||||
'should be defined. Found `None`.')
|
||||
input_dim = input_shape[channel_axis]
|
||||
kernel_shape = self.kernel_size + (self.filters, input_dim)
|
||||
|
||||
self.kernel = self.add_weight(
|
||||
shape=kernel_shape,
|
||||
initializer=self.kernel_initializer,
|
||||
name='kernel',
|
||||
regularizer=self.kernel_regularizer,
|
||||
constraint=self.kernel_constraint)
|
||||
if self.use_bias:
|
||||
self.bias = self.add_weight(
|
||||
shape=(self.filters,),
|
||||
initializer=self.bias_initializer,
|
||||
name='bias',
|
||||
regularizer=self.bias_regularizer,
|
||||
constraint=self.bias_constraint)
|
||||
else:
|
||||
self.bias = None
|
||||
# Set input spec.
|
||||
self.input_spec = InputSpec(ndim=4, axes={channel_axis: input_dim})
|
||||
self.built = True
|
||||
|
||||
def call(self, inputs):
|
||||
input_shape = K.shape(inputs)
|
||||
batch_size = input_shape[0]
|
||||
if self.data_format == 'channels_first':
|
||||
h_axis, w_axis = 2, 3
|
||||
else:
|
||||
h_axis, w_axis = 1, 2
|
||||
|
||||
height, width = input_shape[h_axis], input_shape[w_axis]
|
||||
kernel_h, kernel_w = self.kernel_size
|
||||
stride_h, stride_w = self.strides
|
||||
|
||||
# Infer the dynamic output shape:
|
||||
out_height = conv_utils.deconv_length(height, stride_h, kernel_h,
|
||||
self.padding)
|
||||
out_width = conv_utils.deconv_length(width, stride_w, kernel_w,
|
||||
self.padding)
|
||||
if self.data_format == 'channels_first':
|
||||
output_shape = (batch_size, self.filters, out_height, out_width)
|
||||
else:
|
||||
output_shape = (batch_size, out_height, out_width, self.filters)
|
||||
|
||||
outputs = K.conv2d_transpose(
|
||||
inputs,
|
||||
self.kernel,
|
||||
output_shape,
|
||||
self.strides,
|
||||
padding=self.padding,
|
||||
data_format=self.data_format)
|
||||
|
||||
if self.bias:
|
||||
outputs = K.bias_add(outputs, self.bias, data_format=self.data_format)
|
||||
|
||||
if self.activation is not None:
|
||||
return self.activation(outputs)
|
||||
return outputs
|
||||
|
||||
def _compute_output_shape(self, input_shape):
|
||||
input_shape = tensor_shape.TensorShape(input_shape).as_list()
|
||||
output_shape = list(input_shape)
|
||||
if self.data_format == 'channels_first':
|
||||
c_axis, h_axis, w_axis = 1, 2, 3
|
||||
else:
|
||||
c_axis, h_axis, w_axis = 3, 1, 2
|
||||
|
||||
kernel_h, kernel_w = self.kernel_size
|
||||
stride_h, stride_w = self.strides
|
||||
|
||||
output_shape[c_axis] = self.filters
|
||||
output_shape[h_axis] = conv_utils.deconv_length(
|
||||
output_shape[h_axis], stride_h, kernel_h, self.padding)
|
||||
output_shape[w_axis] = conv_utils.deconv_length(
|
||||
output_shape[w_axis], stride_w, kernel_w, self.padding)
|
||||
return tensor_shape.TensorShape(output_shape)
|
||||
super(Conv2DTranspose, self).build(input_shape)
|
||||
# TODO(fchollet): move weight constraint support to core layers.
|
||||
if self.kernel_constraint:
|
||||
self.constraints[self.kernel] = self.kernel_constraint
|
||||
if self.use_bias and self.bias_constraint:
|
||||
self.constraints[self.bias] = self.bias_constraint
|
||||
|
||||
def get_config(self):
|
||||
config = super(Conv2DTranspose, self).get_config()
|
||||
config.pop('dilation_rate')
|
||||
return config
|
||||
config = {
|
||||
'filters': self.filters,
|
||||
'kernel_size': self.kernel_size,
|
||||
'strides': self.strides,
|
||||
'padding': self.padding,
|
||||
'data_format': self.data_format,
|
||||
'activation': activations.serialize(self.activation),
|
||||
'use_bias': self.use_bias,
|
||||
'kernel_initializer': initializers.serialize(self.kernel_initializer),
|
||||
'bias_initializer': initializers.serialize(self.bias_initializer),
|
||||
'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
|
||||
'bias_regularizer': regularizers.serialize(self.bias_regularizer),
|
||||
'activity_regularizer':
|
||||
regularizers.serialize(self.activity_regularizer),
|
||||
'kernel_constraint': constraints.serialize(self.kernel_constraint),
|
||||
'bias_constraint': constraints.serialize(self.bias_constraint)
|
||||
}
|
||||
base_config = super(Conv2DTranspose, self).get_config()
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
||||
|
||||
|
||||
class SeparableConv2D(Conv2D):
|
||||
class SeparableConv2D(tf_convolutional_layers.SeparableConv2D, Layer):
|
||||
"""Depthwise separable 2D convolution.
|
||||
|
||||
Separable convolutions consist in first performing
|
||||
@ -909,126 +695,68 @@ class SeparableConv2D(Conv2D):
|
||||
pointwise_constraint=None,
|
||||
bias_constraint=None,
|
||||
**kwargs):
|
||||
if data_format is None:
|
||||
data_format = K.image_data_format()
|
||||
super(SeparableConv2D, self).__init__(
|
||||
filters=filters,
|
||||
kernel_size=kernel_size,
|
||||
strides=strides,
|
||||
padding=padding,
|
||||
data_format=data_format,
|
||||
activation=activation,
|
||||
activation=activations.get(activation),
|
||||
use_bias=use_bias,
|
||||
bias_regularizer=bias_regularizer,
|
||||
activity_regularizer=activity_regularizer,
|
||||
bias_constraint=bias_constraint,
|
||||
depthwise_initializer=initializers.get(depthwise_initializer),
|
||||
pointwise_initializer=initializers.get(pointwise_initializer),
|
||||
bias_initializer=initializers.get(bias_initializer),
|
||||
depthwise_regularizer=regularizers.get(depthwise_regularizer),
|
||||
pointwise_regularizer=regularizers.get(pointwise_regularizer),
|
||||
bias_regularizer=regularizers.get(bias_regularizer),
|
||||
activity_regularizer=regularizers.get(activity_regularizer),
|
||||
**kwargs)
|
||||
self.depth_multiplier = depth_multiplier
|
||||
self.depthwise_initializer = initializers.get(depthwise_initializer)
|
||||
self.pointwise_initializer = initializers.get(pointwise_initializer)
|
||||
self.depthwise_regularizer = regularizers.get(depthwise_regularizer)
|
||||
self.pointwise_regularizer = regularizers.get(pointwise_regularizer)
|
||||
# TODO(fchollet): move weight constraint support to core layers.
|
||||
self.depthwise_constraint = constraints.get(depthwise_constraint)
|
||||
self.pointwise_constraint = constraints.get(pointwise_constraint)
|
||||
self.bias_constraint = constraints.get(bias_constraint)
|
||||
|
||||
def build(self, input_shape):
|
||||
input_shape = tensor_shape.TensorShape(input_shape).as_list()
|
||||
if len(input_shape) < 4:
|
||||
raise ValueError('Inputs to `SeparableConv2D` should have rank 4. '
|
||||
'Received input shape:', str(input_shape))
|
||||
if self.data_format == 'channels_first':
|
||||
channel_axis = 1
|
||||
else:
|
||||
channel_axis = 3
|
||||
if input_shape[channel_axis] is None:
|
||||
raise ValueError('The channel dimension of the inputs to '
|
||||
'`SeparableConv2D` '
|
||||
'should be defined. Found `None`.')
|
||||
input_dim = int(input_shape[channel_axis])
|
||||
depthwise_kernel_shape = (self.kernel_size[0], self.kernel_size[1],
|
||||
input_dim, self.depth_multiplier)
|
||||
pointwise_kernel_shape = (1, 1, self.depth_multiplier * input_dim,
|
||||
self.filters)
|
||||
|
||||
self.depthwise_kernel = self.add_weight(
|
||||
shape=depthwise_kernel_shape,
|
||||
initializer=self.depthwise_initializer,
|
||||
name='depthwise_kernel',
|
||||
regularizer=self.depthwise_regularizer,
|
||||
constraint=self.depthwise_constraint)
|
||||
self.pointwise_kernel = self.add_weight(
|
||||
shape=pointwise_kernel_shape,
|
||||
initializer=self.pointwise_initializer,
|
||||
name='pointwise_kernel',
|
||||
regularizer=self.pointwise_regularizer,
|
||||
constraint=self.pointwise_constraint)
|
||||
|
||||
if self.use_bias:
|
||||
self.bias = self.add_weight(
|
||||
shape=(self.filters,),
|
||||
initializer=self.bias_initializer,
|
||||
name='bias',
|
||||
regularizer=self.bias_regularizer,
|
||||
constraint=self.bias_constraint)
|
||||
else:
|
||||
self.bias = None
|
||||
# Set input spec.
|
||||
self.input_spec = InputSpec(ndim=4, axes={channel_axis: input_dim})
|
||||
self.built = True
|
||||
|
||||
def call(self, inputs):
|
||||
outputs = K.separable_conv2d(
|
||||
inputs,
|
||||
self.depthwise_kernel,
|
||||
self.pointwise_kernel,
|
||||
data_format=self.data_format,
|
||||
strides=self.strides,
|
||||
padding=self.padding)
|
||||
|
||||
if self.bias:
|
||||
outputs = K.bias_add(outputs, self.bias, data_format=self.data_format)
|
||||
|
||||
if self.activation is not None:
|
||||
return self.activation(outputs)
|
||||
return outputs
|
||||
|
||||
def _compute_output_shape(self, input_shape):
|
||||
input_shape = tensor_shape.TensorShape(input_shape).as_list()
|
||||
if self.data_format == 'channels_first':
|
||||
rows = input_shape[2]
|
||||
cols = input_shape[3]
|
||||
else:
|
||||
rows = input_shape[1]
|
||||
cols = input_shape[2]
|
||||
|
||||
rows = conv_utils.conv_output_length(rows, self.kernel_size[0],
|
||||
self.padding, self.strides[0])
|
||||
cols = conv_utils.conv_output_length(cols, self.kernel_size[1],
|
||||
self.padding, self.strides[1])
|
||||
if self.data_format == 'channels_first':
|
||||
return tensor_shape.TensorShape(
|
||||
[input_shape[0], self.filters, rows, cols])
|
||||
else:
|
||||
return tensor_shape.TensorShape(
|
||||
[input_shape[0], rows, cols, self.filters])
|
||||
super(SeparableConv2D, self).build(input_shape)
|
||||
# TODO(fchollet): move weight constraint support to core layers.
|
||||
if self.depthwise_constraint:
|
||||
self.constraints[self.depthwise_kernel] = self.depthwise_constraint
|
||||
if self.pointwise_constraint:
|
||||
self.constraints[self.pointwise_kernel] = self.pointwise_constraint
|
||||
if self.use_bias and self.bias_constraint:
|
||||
self.constraints[self.bias] = self.bias_constraint
|
||||
|
||||
def get_config(self):
|
||||
config = super(SeparableConv2D, self).get_config()
|
||||
config.pop('kernel_initializer')
|
||||
config.pop('kernel_regularizer')
|
||||
config.pop('kernel_constraint')
|
||||
config['depth_multiplier'] = self.depth_multiplier
|
||||
config['depthwise_initializer'] = initializers.serialize(
|
||||
self.depthwise_initializer)
|
||||
config['pointwise_initializer'] = initializers.serialize(
|
||||
self.pointwise_initializer)
|
||||
config['depthwise_regularizer'] = regularizers.serialize(
|
||||
self.depthwise_regularizer)
|
||||
config['pointwise_regularizer'] = regularizers.serialize(
|
||||
self.pointwise_regularizer)
|
||||
config['depthwise_constraint'] = constraints.serialize(
|
||||
self.depthwise_constraint)
|
||||
config['pointwise_constraint'] = constraints.serialize(
|
||||
self.pointwise_constraint)
|
||||
return config
|
||||
config = {
|
||||
'filters': self.filters,
|
||||
'kernel_size': self.kernel_size,
|
||||
'strides': self.strides,
|
||||
'padding': self.padding,
|
||||
'data_format': self.data_format,
|
||||
'activation': activations.serialize(self.activation),
|
||||
'use_bias': self.use_bias,
|
||||
'depthwise_initializer': initializers.serialize(
|
||||
self.depthwise_initializer),
|
||||
'pointwise_initializer': initializers.serialize(
|
||||
self.pointwise_initializer),
|
||||
'bias_initializer': initializers.serialize(self.bias_initializer),
|
||||
'depthwise_regularizer': regularizers.serialize(
|
||||
self.depthwise_regularizer),
|
||||
'pointwise_regularizer': regularizers.serialize(
|
||||
self.pointwise_regularizer),
|
||||
'bias_regularizer': regularizers.serialize(self.bias_regularizer),
|
||||
'activity_regularizer':
|
||||
regularizers.serialize(self.activity_regularizer),
|
||||
'depthwise_constraint': constraints.serialize(
|
||||
self.depthwise_constraint),
|
||||
'pointwise_constraint': constraints.serialize(
|
||||
self.pointwise_constraint),
|
||||
'bias_constraint': constraints.serialize(self.bias_constraint)
|
||||
}
|
||||
base_config = super(SeparableConv2D, self).get_config()
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
||||
|
||||
|
||||
class UpSampling1D(Layer):
|
||||
|
@ -27,24 +27,7 @@ from tensorflow.python.platform import test
|
||||
|
||||
class Convolution1DTest(test.TestCase):
|
||||
|
||||
def test_causal_dilated_conv1d(self):
|
||||
# Causal:
|
||||
with self.test_session():
|
||||
testing_utils.layer_test(
|
||||
keras.layers.Conv1D,
|
||||
input_data=np.reshape(np.arange(4, dtype='float32'), (1, 4, 1)),
|
||||
kwargs={
|
||||
'filters': 1,
|
||||
'kernel_size': 2,
|
||||
'dilation_rate': 1,
|
||||
'padding': 'causal',
|
||||
'kernel_initializer': 'ones',
|
||||
'use_bias': False,
|
||||
},
|
||||
expected_output=[[[0], [1], [3], [5]]])
|
||||
|
||||
def test_dilated_conv1d(self):
|
||||
# Non-causal:
|
||||
with self.test_session():
|
||||
testing_utils.layer_test(
|
||||
keras.layers.Conv1D,
|
||||
|
@ -85,7 +85,7 @@ class Masking(Layer):
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
||||
|
||||
|
||||
class Dropout(Layer):
|
||||
class Dropout(tf_core_layers.Dropout, Layer):
|
||||
"""Applies Dropout to the input.
|
||||
|
||||
Dropout consists in randomly setting
|
||||
@ -104,24 +104,18 @@ class Dropout(Layer):
|
||||
"""
|
||||
|
||||
def __init__(self, rate, noise_shape=None, seed=None, **kwargs):
|
||||
super(Dropout, self).__init__(**kwargs)
|
||||
self.rate = min(1., max(0., rate))
|
||||
self.noise_shape = noise_shape
|
||||
self.seed = seed
|
||||
self.supports_masking = True
|
||||
|
||||
def _get_noise_shape(self, _):
|
||||
return self.noise_shape
|
||||
# Inheritance call order:
|
||||
# 1) tf.layers.Dropout, 2) keras.layers.Layer, 3) tf.layers.Layer
|
||||
super(Dropout, self).__init__(**kwargs)
|
||||
|
||||
def call(self, inputs, training=None):
|
||||
if 0. < self.rate < 1.:
|
||||
noise_shape = self._get_noise_shape(inputs)
|
||||
|
||||
def dropped_inputs():
|
||||
return K.dropout(inputs, self.rate, noise_shape, seed=self.seed)
|
||||
|
||||
return K.in_train_phase(dropped_inputs, inputs, training=training)
|
||||
return inputs
|
||||
if training is None:
|
||||
training = K.learning_phase()
|
||||
output = super(Dropout, self).call(inputs, training=training)
|
||||
if training is K.learning_phase():
|
||||
output._uses_learning_phase = True # pylint: disable=protected-access
|
||||
return output
|
||||
|
||||
def get_config(self):
|
||||
config = {'rate': self.rate}
|
||||
@ -726,45 +720,32 @@ class Dense(tf_core_layers.Dense, Layer):
|
||||
bias_regularizer=regularizers.get(bias_regularizer),
|
||||
activity_regularizer=regularizers.get(activity_regularizer),
|
||||
**kwargs)
|
||||
|
||||
# TODO(fchollet): move weight constraint support to core layers.
|
||||
self.kernel_constraint = constraints.get(kernel_constraint)
|
||||
self.bias_constraint = constraints.get(bias_constraint)
|
||||
self.input_spec = InputSpec(min_ndim=2)
|
||||
self.supports_masking = True
|
||||
|
||||
def build(self, input_shape):
|
||||
assert len(input_shape) >= 2
|
||||
input_dim = input_shape[-1]
|
||||
super(Dense, self).build(input_shape)
|
||||
self.input_spec = InputSpec(min_ndim=2, axes={-1: input_dim})
|
||||
# TODO(fchollet): move weight constraint support to core layers.
|
||||
if self.kernel_constraint:
|
||||
self.constraints[self.kernel] = self.kernel_constraint
|
||||
if self.use_bias and self.bias_constraint:
|
||||
self.constraints[self.bias] = self.bias_constraint
|
||||
self.built = True
|
||||
|
||||
def get_config(self):
|
||||
config = {
|
||||
'units':
|
||||
self.units,
|
||||
'activation':
|
||||
activations.serialize(self.activation),
|
||||
'use_bias':
|
||||
self.use_bias,
|
||||
'kernel_initializer':
|
||||
initializers.serialize(self.kernel_initializer),
|
||||
'bias_initializer':
|
||||
initializers.serialize(self.bias_initializer),
|
||||
'kernel_regularizer':
|
||||
regularizers.serialize(self.kernel_regularizer),
|
||||
'bias_regularizer':
|
||||
regularizers.serialize(self.bias_regularizer),
|
||||
'units': self.units,
|
||||
'activation': activations.serialize(self.activation),
|
||||
'use_bias': self.use_bias,
|
||||
'kernel_initializer': initializers.serialize(self.kernel_initializer),
|
||||
'bias_initializer': initializers.serialize(self.bias_initializer),
|
||||
'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
|
||||
'bias_regularizer': regularizers.serialize(self.bias_regularizer),
|
||||
'activity_regularizer':
|
||||
regularizers.serialize(self.activity_regularizer),
|
||||
'kernel_constraint':
|
||||
constraints.serialize(self.kernel_constraint),
|
||||
'bias_constraint':
|
||||
constraints.serialize(self.bias_constraint)
|
||||
'kernel_constraint': constraints.serialize(self.kernel_constraint),
|
||||
'bias_constraint': constraints.serialize(self.bias_constraint)
|
||||
}
|
||||
base_config = super(Dense, self).get_config()
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
||||
|
@ -22,12 +22,11 @@ from tensorflow.contrib.keras.python.keras import backend as K
|
||||
from tensorflow.contrib.keras.python.keras import constraints
|
||||
from tensorflow.contrib.keras.python.keras import initializers
|
||||
from tensorflow.contrib.keras.python.keras import regularizers
|
||||
from tensorflow.contrib.keras.python.keras.engine import InputSpec
|
||||
from tensorflow.contrib.keras.python.keras.engine import Layer
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.layers import normalization as tf_normalization_layers
|
||||
|
||||
|
||||
class BatchNormalization(Layer):
|
||||
class BatchNormalization(tf_normalization_layers.BatchNormalization, Layer):
|
||||
"""Batch normalization layer (Ioffe and Szegedy, 2014).
|
||||
|
||||
Normalize the activations of the previous layer at each batch,
|
||||
@ -86,148 +85,59 @@ class BatchNormalization(Layer):
|
||||
beta_constraint=None,
|
||||
gamma_constraint=None,
|
||||
**kwargs):
|
||||
super(BatchNormalization, self).__init__(**kwargs)
|
||||
self.supports_masking = True
|
||||
self.axis = axis
|
||||
self.momentum = momentum
|
||||
self.epsilon = epsilon
|
||||
self.center = center
|
||||
self.scale = scale
|
||||
self.beta_initializer = initializers.get(beta_initializer)
|
||||
self.gamma_initializer = initializers.get(gamma_initializer)
|
||||
self.moving_mean_initializer = initializers.get(moving_mean_initializer)
|
||||
self.moving_variance_initializer = initializers.get(
|
||||
moving_variance_initializer)
|
||||
self.beta_regularizer = regularizers.get(beta_regularizer)
|
||||
self.gamma_regularizer = regularizers.get(gamma_regularizer)
|
||||
super(BatchNormalization, self).__init__(
|
||||
axis=axis,
|
||||
momentum=momentum,
|
||||
epsilon=epsilon,
|
||||
center=center,
|
||||
scale=scale,
|
||||
beta_initializer=initializers.get(beta_initializer),
|
||||
gamma_initializer=initializers.get(gamma_initializer),
|
||||
moving_mean_initializer=initializers.get(moving_mean_initializer),
|
||||
moving_variance_initializer=initializers.get(
|
||||
moving_variance_initializer),
|
||||
beta_regularizer=regularizers.get(beta_regularizer),
|
||||
gamma_regularizer=regularizers.get(gamma_regularizer),
|
||||
**kwargs
|
||||
)
|
||||
# TODO(fchollet): move weight constraint support to core layers.
|
||||
self.beta_constraint = constraints.get(beta_constraint)
|
||||
self.gamma_constraint = constraints.get(gamma_constraint)
|
||||
|
||||
def build(self, input_shape):
|
||||
input_shape = tensor_shape.TensorShape(input_shape).as_list()
|
||||
dim = input_shape[self.axis]
|
||||
if dim is None:
|
||||
raise ValueError('Axis ' + str(self.axis) + ' of '
|
||||
'input tensor should have a defined dimension '
|
||||
'but the layer received an input with shape ' +
|
||||
str(input_shape) + '.')
|
||||
self.input_spec = InputSpec(ndim=len(input_shape), axes={self.axis: dim})
|
||||
shape = (dim,)
|
||||
|
||||
if self.scale:
|
||||
self.gamma = self.add_weight(
|
||||
shape=shape,
|
||||
name='gamma',
|
||||
initializer=self.gamma_initializer,
|
||||
regularizer=self.gamma_regularizer,
|
||||
constraint=self.gamma_constraint)
|
||||
else:
|
||||
self.gamma = None
|
||||
if self.center:
|
||||
self.beta = self.add_weight(
|
||||
shape=shape,
|
||||
name='beta',
|
||||
initializer=self.beta_initializer,
|
||||
regularizer=self.beta_regularizer,
|
||||
constraint=self.beta_constraint)
|
||||
else:
|
||||
self.beta = None
|
||||
self.moving_mean = self.add_weight(
|
||||
shape=shape,
|
||||
name='moving_mean',
|
||||
initializer=self.moving_mean_initializer,
|
||||
trainable=False)
|
||||
self.moving_variance = self.add_weight(
|
||||
shape=shape,
|
||||
name='moving_variance',
|
||||
initializer=self.moving_variance_initializer,
|
||||
trainable=False)
|
||||
self.built = True
|
||||
super(BatchNormalization, self).build(input_shape)
|
||||
# TODO(fchollet): move weight constraint support to core layers.
|
||||
if self.center and self.beta_constraint:
|
||||
self.constraints[self.beta] = self.beta_constraint
|
||||
if self.scale and self.gamma_constraint:
|
||||
self.constraints[self.gamma] = self.gamma_constraint
|
||||
|
||||
def call(self, inputs, training=None):
|
||||
input_shape = inputs.get_shape().as_list()
|
||||
# Prepare broadcasting shape.
|
||||
ndim = len(input_shape)
|
||||
reduction_axes = list(range(len(input_shape)))
|
||||
del reduction_axes[self.axis]
|
||||
broadcast_shape = [1] * len(input_shape)
|
||||
broadcast_shape[self.axis] = input_shape[self.axis]
|
||||
|
||||
# Determines whether broadcasting is needed.
|
||||
needs_broadcasting = (sorted(reduction_axes) != list(range(ndim))[:-1])
|
||||
|
||||
normed, mean, variance = K.normalize_batch_in_training(
|
||||
inputs, self.gamma, self.beta, reduction_axes, epsilon=self.epsilon)
|
||||
|
||||
if training in {0, False}:
|
||||
return normed
|
||||
else:
|
||||
self.add_update([
|
||||
K.moving_average_update(self.moving_mean, mean, self.momentum),
|
||||
K.moving_average_update(self.moving_variance, variance, self.momentum)
|
||||
], inputs)
|
||||
|
||||
def normalize_inference():
|
||||
if needs_broadcasting:
|
||||
# In this case we must explicitly broadcast all parameters.
|
||||
broadcast_moving_mean = K.reshape(self.moving_mean, broadcast_shape)
|
||||
broadcast_moving_variance = K.reshape(self.moving_variance,
|
||||
broadcast_shape)
|
||||
if self.center:
|
||||
broadcast_beta = K.reshape(self.beta, broadcast_shape)
|
||||
else:
|
||||
broadcast_beta = None
|
||||
if self.scale:
|
||||
broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
|
||||
else:
|
||||
broadcast_gamma = None
|
||||
return K.batch_normalization(
|
||||
inputs,
|
||||
broadcast_moving_mean,
|
||||
broadcast_moving_variance,
|
||||
broadcast_beta,
|
||||
broadcast_gamma,
|
||||
epsilon=self.epsilon)
|
||||
else:
|
||||
return K.batch_normalization(
|
||||
inputs,
|
||||
self.moving_mean,
|
||||
self.moving_variance,
|
||||
self.beta,
|
||||
self.gamma,
|
||||
epsilon=self.epsilon)
|
||||
|
||||
# Pick the normalized form corresponding to the training phase.
|
||||
return K.in_train_phase(normed, normalize_inference, training=training)
|
||||
if training is None:
|
||||
training = K.learning_phase()
|
||||
output = super(BatchNormalization, self).call(inputs, training=training)
|
||||
if training is K.learning_phase():
|
||||
output._uses_learning_phase = True # pylint: disable=protected-access
|
||||
return output
|
||||
|
||||
def get_config(self):
|
||||
config = {
|
||||
'axis':
|
||||
self.axis,
|
||||
'momentum':
|
||||
self.momentum,
|
||||
'epsilon':
|
||||
self.epsilon,
|
||||
'center':
|
||||
self.center,
|
||||
'scale':
|
||||
self.scale,
|
||||
'beta_initializer':
|
||||
initializers.serialize(self.beta_initializer),
|
||||
'gamma_initializer':
|
||||
initializers.serialize(self.gamma_initializer),
|
||||
'axis': self.axis,
|
||||
'momentum': self.momentum,
|
||||
'epsilon': self.epsilon,
|
||||
'center': self.center,
|
||||
'scale': self.scale,
|
||||
'beta_initializer': initializers.serialize(self.beta_initializer),
|
||||
'gamma_initializer': initializers.serialize(self.gamma_initializer),
|
||||
'moving_mean_initializer':
|
||||
initializers.serialize(self.moving_mean_initializer),
|
||||
'moving_variance_initializer':
|
||||
initializers.serialize(self.moving_variance_initializer),
|
||||
'beta_regularizer':
|
||||
regularizers.serialize(self.beta_regularizer),
|
||||
'gamma_regularizer':
|
||||
regularizers.serialize(self.gamma_regularizer),
|
||||
'beta_constraint':
|
||||
constraints.serialize(self.beta_constraint),
|
||||
'gamma_constraint':
|
||||
constraints.serialize(self.gamma_constraint)
|
||||
'beta_regularizer': regularizers.serialize(self.beta_regularizer),
|
||||
'gamma_regularizer': regularizers.serialize(self.gamma_regularizer),
|
||||
'beta_constraint': constraints.serialize(self.beta_constraint),
|
||||
'gamma_constraint': constraints.serialize(self.gamma_constraint)
|
||||
}
|
||||
base_config = super(BatchNormalization, self).get_config()
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
||||
|
@ -116,19 +116,21 @@ class NoiseLayersTest(test.TestCase):
|
||||
"""
|
||||
with self.test_session():
|
||||
# Test single layer reuse
|
||||
bn = keras.layers.BatchNormalization(input_shape=(10,))
|
||||
bn = keras.layers.BatchNormalization()
|
||||
x1 = keras.layers.Input(shape=(10,))
|
||||
bn(x1)
|
||||
_ = bn(x1)
|
||||
|
||||
x2 = keras.layers.Input(shape=(10,))
|
||||
y2 = bn(x2)
|
||||
|
||||
x = np.random.normal(loc=5.0, scale=10.0, size=(2, 10))
|
||||
model = keras.models.Model(x2, y2)
|
||||
assert len(model.updates) == 2
|
||||
|
||||
model.compile('sgd', 'mse')
|
||||
model.train_on_batch(x, x)
|
||||
|
||||
assert len(model.updates) == 2
|
||||
|
||||
# Test model-level reuse
|
||||
x3 = keras.layers.Input(shape=(10,))
|
||||
y3 = model(x3)
|
||||
|
@ -23,51 +23,10 @@ from tensorflow.contrib.keras.python.keras.engine import InputSpec
|
||||
from tensorflow.contrib.keras.python.keras.engine import Layer
|
||||
from tensorflow.contrib.keras.python.keras.utils import conv_utils
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.layers import pooling as tf_pooling_layers
|
||||
|
||||
|
||||
class _Pooling1D(Layer):
|
||||
"""Abstract class for different pooling 1D layers.
|
||||
"""
|
||||
|
||||
def __init__(self, pool_size=2, strides=None, padding='valid', **kwargs):
|
||||
super(_Pooling1D, self).__init__(**kwargs)
|
||||
if strides is None:
|
||||
strides = pool_size
|
||||
self.pool_size = conv_utils.normalize_tuple(pool_size, 1, 'pool_size')
|
||||
self.strides = conv_utils.normalize_tuple(strides, 1, 'strides')
|
||||
self.padding = conv_utils.normalize_padding(padding)
|
||||
self.input_spec = InputSpec(ndim=3)
|
||||
|
||||
def _compute_output_shape(self, input_shape):
|
||||
input_shape = tensor_shape.TensorShape(input_shape).as_list()
|
||||
length = conv_utils.conv_output_length(input_shape[1], self.pool_size[0],
|
||||
self.padding, self.strides[0])
|
||||
return tensor_shape.TensorShape([input_shape[0], length, input_shape[2]])
|
||||
|
||||
def _pooling_function(self, inputs, pool_size, strides, padding, data_format):
|
||||
raise NotImplementedError
|
||||
|
||||
def call(self, inputs):
|
||||
inputs = K.expand_dims(inputs, 2) # add dummy last dimension
|
||||
output = self._pooling_function(
|
||||
inputs=inputs,
|
||||
pool_size=self.pool_size + (1,),
|
||||
strides=self.strides + (1,),
|
||||
padding=self.padding,
|
||||
data_format='channels_last')
|
||||
return K.squeeze(output, 2) # remove dummy last dimension
|
||||
|
||||
def get_config(self):
|
||||
config = {
|
||||
'strides': self.strides,
|
||||
'pool_size': self.pool_size,
|
||||
'padding': self.padding
|
||||
}
|
||||
base_config = super(_Pooling1D, self).get_config()
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
||||
|
||||
|
||||
class MaxPooling1D(_Pooling1D):
|
||||
class MaxPooling1D(tf_pooling_layers.MaxPooling1D, Layer):
|
||||
"""Max pooling operation for temporal data.
|
||||
|
||||
Arguments:
|
||||
@ -85,15 +44,21 @@ class MaxPooling1D(_Pooling1D):
|
||||
"""
|
||||
|
||||
def __init__(self, pool_size=2, strides=None, padding='valid', **kwargs):
|
||||
if strides is None:
|
||||
strides = pool_size
|
||||
super(MaxPooling1D, self).__init__(pool_size, strides, padding, **kwargs)
|
||||
|
||||
def _pooling_function(self, inputs, pool_size, strides, padding, data_format):
|
||||
output = K.pool2d(
|
||||
inputs, pool_size, strides, padding, data_format, pool_mode='max')
|
||||
return output
|
||||
def get_config(self):
|
||||
config = {
|
||||
'strides': self.strides,
|
||||
'pool_size': self.pool_size,
|
||||
'padding': self.padding
|
||||
}
|
||||
base_config = super(MaxPooling1D, self).get_config()
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
||||
|
||||
|
||||
class AveragePooling1D(_Pooling1D):
|
||||
class AveragePooling1D(tf_pooling_layers.AveragePooling1D, Layer):
|
||||
"""Average pooling for temporal data.
|
||||
|
||||
Arguments:
|
||||
@ -111,78 +76,22 @@ class AveragePooling1D(_Pooling1D):
|
||||
"""
|
||||
|
||||
def __init__(self, pool_size=2, strides=None, padding='valid', **kwargs):
|
||||
if strides is None:
|
||||
strides = pool_size
|
||||
super(AveragePooling1D, self).__init__(pool_size, strides, padding,
|
||||
**kwargs)
|
||||
|
||||
def _pooling_function(self, inputs, pool_size, strides, padding, data_format):
|
||||
output = K.pool2d(
|
||||
inputs, pool_size, strides, padding, data_format, pool_mode='avg')
|
||||
return output
|
||||
|
||||
|
||||
class _Pooling2D(Layer):
|
||||
"""Abstract class for different pooling 2D layers.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
pool_size=(2, 2),
|
||||
strides=None,
|
||||
padding='valid',
|
||||
data_format=None,
|
||||
**kwargs):
|
||||
super(_Pooling2D, self).__init__(**kwargs)
|
||||
data_format = conv_utils.normalize_data_format(data_format)
|
||||
if strides is None:
|
||||
strides = pool_size
|
||||
self.pool_size = conv_utils.normalize_tuple(pool_size, 2, 'pool_size')
|
||||
self.strides = conv_utils.normalize_tuple(strides, 2, 'strides')
|
||||
self.padding = conv_utils.normalize_padding(padding)
|
||||
self.data_format = conv_utils.normalize_data_format(data_format)
|
||||
self.input_spec = InputSpec(ndim=4)
|
||||
|
||||
def _compute_output_shape(self, input_shape):
|
||||
input_shape = tensor_shape.TensorShape(input_shape).as_list()
|
||||
if self.data_format == 'channels_first':
|
||||
rows = input_shape[2]
|
||||
cols = input_shape[3]
|
||||
else:
|
||||
rows = input_shape[1]
|
||||
cols = input_shape[2]
|
||||
rows = conv_utils.conv_output_length(rows, self.pool_size[0], self.padding,
|
||||
self.strides[0])
|
||||
cols = conv_utils.conv_output_length(cols, self.pool_size[1], self.padding,
|
||||
self.strides[1])
|
||||
if self.data_format == 'channels_first':
|
||||
return tensor_shape.TensorShape(
|
||||
[input_shape[0], input_shape[1], rows, cols])
|
||||
else:
|
||||
return tensor_shape.TensorShape(
|
||||
[input_shape[0], rows, cols, input_shape[3]])
|
||||
|
||||
def _pooling_function(self, inputs, pool_size, strides, padding, data_format):
|
||||
raise NotImplementedError
|
||||
|
||||
def call(self, inputs):
|
||||
output = self._pooling_function(
|
||||
inputs=inputs,
|
||||
pool_size=self.pool_size,
|
||||
strides=self.strides,
|
||||
padding=self.padding,
|
||||
data_format=self.data_format)
|
||||
return output
|
||||
|
||||
def get_config(self):
|
||||
config = {
|
||||
'pool_size': self.pool_size,
|
||||
'padding': self.padding,
|
||||
'strides': self.strides,
|
||||
'data_format': self.data_format
|
||||
'pool_size': self.pool_size,
|
||||
'padding': self.padding
|
||||
}
|
||||
base_config = super(_Pooling2D, self).get_config()
|
||||
base_config = super(AveragePooling1D, self).get_config()
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
||||
|
||||
|
||||
class MaxPooling2D(_Pooling2D):
|
||||
class MaxPooling2D(tf_pooling_layers.MaxPooling2D, Layer):
|
||||
"""Max pooling operation for spatial data.
|
||||
|
||||
Arguments:
|
||||
@ -229,16 +138,25 @@ class MaxPooling2D(_Pooling2D):
|
||||
padding='valid',
|
||||
data_format=None,
|
||||
**kwargs):
|
||||
if data_format is None:
|
||||
data_format = K.image_data_format()
|
||||
if strides is None:
|
||||
strides = pool_size
|
||||
super(MaxPooling2D, self).__init__(pool_size, strides, padding, data_format,
|
||||
**kwargs)
|
||||
|
||||
def _pooling_function(self, inputs, pool_size, strides, padding, data_format):
|
||||
output = K.pool2d(
|
||||
inputs, pool_size, strides, padding, data_format, pool_mode='max')
|
||||
return output
|
||||
def get_config(self):
|
||||
config = {
|
||||
'pool_size': self.pool_size,
|
||||
'padding': self.padding,
|
||||
'strides': self.strides,
|
||||
'data_format': self.data_format
|
||||
}
|
||||
base_config = super(MaxPooling2D, self).get_config()
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
||||
|
||||
|
||||
class AveragePooling2D(_Pooling2D):
|
||||
class AveragePooling2D(tf_pooling_layers.AveragePooling2D, Layer):
|
||||
"""Average pooling operation for spatial data.
|
||||
|
||||
Arguments:
|
||||
@ -285,68 +203,12 @@ class AveragePooling2D(_Pooling2D):
|
||||
padding='valid',
|
||||
data_format=None,
|
||||
**kwargs):
|
||||
super(AveragePooling2D, self).__init__(pool_size, strides, padding,
|
||||
data_format, **kwargs)
|
||||
|
||||
def _pooling_function(self, inputs, pool_size, strides, padding, data_format):
|
||||
output = K.pool2d(
|
||||
inputs, pool_size, strides, padding, data_format, pool_mode='avg')
|
||||
return output
|
||||
|
||||
|
||||
class _Pooling3D(Layer):
|
||||
"""Abstract class for different pooling 3D layers.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
pool_size=(2, 2, 2),
|
||||
strides=None,
|
||||
padding='valid',
|
||||
data_format=None,
|
||||
**kwargs):
|
||||
super(_Pooling3D, self).__init__(**kwargs)
|
||||
if data_format is None:
|
||||
data_format = K.image_data_format()
|
||||
if strides is None:
|
||||
strides = pool_size
|
||||
self.pool_size = conv_utils.normalize_tuple(pool_size, 3, 'pool_size')
|
||||
self.strides = conv_utils.normalize_tuple(strides, 3, 'strides')
|
||||
self.padding = conv_utils.normalize_padding(padding)
|
||||
self.data_format = conv_utils.normalize_data_format(data_format)
|
||||
self.input_spec = InputSpec(ndim=5)
|
||||
|
||||
def _compute_output_shape(self, input_shape):
|
||||
input_shape = tensor_shape.TensorShape(input_shape).as_list()
|
||||
if self.data_format == 'channels_first':
|
||||
len_dim1 = input_shape[2]
|
||||
len_dim2 = input_shape[3]
|
||||
len_dim3 = input_shape[4]
|
||||
else:
|
||||
len_dim1 = input_shape[1]
|
||||
len_dim2 = input_shape[2]
|
||||
len_dim3 = input_shape[3]
|
||||
len_dim1 = conv_utils.conv_output_length(len_dim1, self.pool_size[0],
|
||||
self.padding, self.strides[0])
|
||||
len_dim2 = conv_utils.conv_output_length(len_dim2, self.pool_size[1],
|
||||
self.padding, self.strides[1])
|
||||
len_dim3 = conv_utils.conv_output_length(len_dim3, self.pool_size[2],
|
||||
self.padding, self.strides[2])
|
||||
if self.data_format == 'channels_first':
|
||||
return tensor_shape.TensorShape(
|
||||
[input_shape[0], input_shape[1], len_dim1, len_dim2, len_dim3])
|
||||
else:
|
||||
return tensor_shape.TensorShape(
|
||||
[input_shape[0], len_dim1, len_dim2, len_dim3, input_shape[4]])
|
||||
|
||||
def _pooling_function(self, inputs, pool_size, strides, padding, data_format):
|
||||
raise NotImplementedError
|
||||
|
||||
def call(self, inputs):
|
||||
output = self._pooling_function(
|
||||
inputs=inputs,
|
||||
pool_size=self.pool_size,
|
||||
strides=self.strides,
|
||||
padding=self.padding,
|
||||
data_format=self.data_format)
|
||||
return output
|
||||
super(AveragePooling2D, self).__init__(pool_size, strides, padding,
|
||||
data_format, **kwargs)
|
||||
|
||||
def get_config(self):
|
||||
config = {
|
||||
@ -355,11 +217,11 @@ class _Pooling3D(Layer):
|
||||
'strides': self.strides,
|
||||
'data_format': self.data_format
|
||||
}
|
||||
base_config = super(_Pooling3D, self).get_config()
|
||||
base_config = super(AveragePooling2D, self).get_config()
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
||||
|
||||
|
||||
class MaxPooling3D(_Pooling3D):
|
||||
class MaxPooling3D(tf_pooling_layers.MaxPooling3D, Layer):
|
||||
"""Max pooling operation for 3D data (spatial or spatio-temporal).
|
||||
|
||||
Arguments:
|
||||
@ -402,16 +264,25 @@ class MaxPooling3D(_Pooling3D):
|
||||
padding='valid',
|
||||
data_format=None,
|
||||
**kwargs):
|
||||
if data_format is None:
|
||||
data_format = K.image_data_format()
|
||||
if strides is None:
|
||||
strides = pool_size
|
||||
super(MaxPooling3D, self).__init__(pool_size, strides, padding, data_format,
|
||||
**kwargs)
|
||||
|
||||
def _pooling_function(self, inputs, pool_size, strides, padding, data_format):
|
||||
output = K.pool3d(
|
||||
inputs, pool_size, strides, padding, data_format, pool_mode='max')
|
||||
return output
|
||||
def get_config(self):
|
||||
config = {
|
||||
'pool_size': self.pool_size,
|
||||
'padding': self.padding,
|
||||
'strides': self.strides,
|
||||
'data_format': self.data_format
|
||||
}
|
||||
base_config = super(MaxPooling3D, self).get_config()
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
||||
|
||||
|
||||
class AveragePooling3D(_Pooling3D):
|
||||
class AveragePooling3D(tf_pooling_layers.AveragePooling3D, Layer):
|
||||
"""Average pooling operation for 3D data (spatial or spatio-temporal).
|
||||
|
||||
Arguments:
|
||||
@ -454,13 +325,22 @@ class AveragePooling3D(_Pooling3D):
|
||||
padding='valid',
|
||||
data_format=None,
|
||||
**kwargs):
|
||||
if data_format is None:
|
||||
data_format = K.image_data_format()
|
||||
if strides is None:
|
||||
strides = pool_size
|
||||
super(AveragePooling3D, self).__init__(pool_size, strides, padding,
|
||||
data_format, **kwargs)
|
||||
|
||||
def _pooling_function(self, inputs, pool_size, strides, padding, data_format):
|
||||
output = K.pool3d(
|
||||
inputs, pool_size, strides, padding, data_format, pool_mode='avg')
|
||||
return output
|
||||
def get_config(self):
|
||||
config = {
|
||||
'pool_size': self.pool_size,
|
||||
'padding': self.padding,
|
||||
'strides': self.strides,
|
||||
'data_format': self.data_format
|
||||
}
|
||||
base_config = super(AveragePooling3D, self).get_config()
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
||||
|
||||
|
||||
class _GlobalPooling1D(Layer):
|
||||
|
@ -71,7 +71,6 @@ class AvgPool2DTest(test.TestCase):
|
||||
height, width = 3, 6
|
||||
images = np.random.uniform(size=(5, 2, height, width))
|
||||
output = _layers.avg_pool2d(images, [3, 3], data_format='NCHW')
|
||||
self.assertEquals(output.op.name, 'AvgPool2D/AvgPool')
|
||||
self.assertListEqual(output.get_shape().as_list(), [5, 2, 1, 2])
|
||||
|
||||
def testCollectOutputs(self):
|
||||
@ -2692,7 +2691,6 @@ class MaxPool2DTest(test.TestCase):
|
||||
height, width = 3, 6
|
||||
images = np.random.uniform(size=(5, 3, height, width)).astype(np.float32)
|
||||
output = _layers.max_pool2d(images, [3, 3], data_format='NCHW')
|
||||
self.assertEquals(output.op.name, 'MaxPool2D/MaxPool')
|
||||
self.assertListEqual(output.get_shape().as_list(), [5, 3, 1, 2])
|
||||
|
||||
def testCollectOutputs(self):
|
||||
|
@ -503,6 +503,7 @@ tf_gen_op_libs(
|
||||
"control_flow_ops",
|
||||
"ctc_ops",
|
||||
"data_flow_ops",
|
||||
"dataset_ops",
|
||||
"function_ops",
|
||||
"functional_ops",
|
||||
"image_ops",
|
||||
@ -580,6 +581,7 @@ cc_library(
|
||||
":control_flow_ops_op_lib",
|
||||
":ctc_ops_op_lib",
|
||||
":data_flow_ops_op_lib",
|
||||
":dataset_ops_op_lib",
|
||||
":function_ops_op_lib",
|
||||
":functional_ops_op_lib",
|
||||
":image_ops_op_lib",
|
||||
@ -707,6 +709,7 @@ cc_library(
|
||||
"//tensorflow/core/kernels:control_flow_ops",
|
||||
"//tensorflow/core/kernels:ctc_ops",
|
||||
"//tensorflow/core/kernels:data_flow",
|
||||
"//tensorflow/core/kernels:dataset_ops",
|
||||
"//tensorflow/core/kernels:fake_quant_ops",
|
||||
"//tensorflow/core/kernels:function_ops",
|
||||
"//tensorflow/core/kernels:image",
|
||||
@ -1408,11 +1411,6 @@ tf_cuda_library(
|
||||
"framework/**/*.cc",
|
||||
"util/**/*.h",
|
||||
"util/**/*.cc",
|
||||
] + [
|
||||
"graph/edgeset.h",
|
||||
"graph/edgeset.cc",
|
||||
"graph/graph.h",
|
||||
"graph/graph.cc",
|
||||
],
|
||||
exclude = [
|
||||
"**/*test*",
|
||||
@ -1553,6 +1551,8 @@ tf_cuda_library(
|
||||
"graph/colors.cc",
|
||||
"graph/control_flow.cc",
|
||||
"graph/costmodel.cc",
|
||||
"graph/edgeset.cc",
|
||||
"graph/graph.cc",
|
||||
"graph/graph_constructor.cc",
|
||||
"graph/graph_def_builder.cc",
|
||||
"graph/graph_partition.cc",
|
||||
|
@ -48,9 +48,9 @@ class ConstantFoldingTest : public ::testing::Test {
|
||||
TensorShape shape) {
|
||||
EXPECT_TRUE(n->IsConstant());
|
||||
const TensorProto* tensor_proto;
|
||||
TF_EXPECT_OK(GetNodeAttr(n->attrs(), "value", &tensor_proto));
|
||||
TF_EXPECT_OK(GetNodeAttr(n->def(), "value", &tensor_proto));
|
||||
DataType dtype;
|
||||
TF_EXPECT_OK(GetNodeAttr(n->attrs(), "dtype", &dtype));
|
||||
TF_EXPECT_OK(GetNodeAttr(n->def(), "dtype", &dtype));
|
||||
Tensor t(dtype);
|
||||
EXPECT_TRUE(t.FromProto(*tensor_proto));
|
||||
test::ExpectClose(t, test::AsTensor(values, shape));
|
||||
@ -61,9 +61,9 @@ class ConstantFoldingTest : public ::testing::Test {
|
||||
TensorShape shape) {
|
||||
EXPECT_TRUE(n->IsConstant());
|
||||
const TensorProto* tensor_proto;
|
||||
TF_EXPECT_OK(GetNodeAttr(n->attrs(), "value", &tensor_proto));
|
||||
TF_EXPECT_OK(GetNodeAttr(n->def(), "value", &tensor_proto));
|
||||
DataType dtype;
|
||||
TF_EXPECT_OK(GetNodeAttr(n->attrs(), "dtype", &dtype));
|
||||
TF_EXPECT_OK(GetNodeAttr(n->def(), "dtype", &dtype));
|
||||
Tensor t(dtype);
|
||||
EXPECT_TRUE(t.FromProto(*tensor_proto));
|
||||
test::ExpectTensorEqual<T>(t, test::AsTensor(values, shape));
|
||||
|
@ -92,28 +92,31 @@ bool SetTimelineLabel(const Node* node, NodeExecStats* node_stats) {
|
||||
}
|
||||
}
|
||||
}
|
||||
const AttrSlice attrs = node->attrs();
|
||||
string text;
|
||||
const NodeDef& def = node->def();
|
||||
string text = "";
|
||||
if (IsSend(node)) {
|
||||
string tensor_name;
|
||||
TF_CHECK_OK(GetNodeAttr(attrs, "tensor_name", &tensor_name));
|
||||
TF_CHECK_OK(GetNodeAttr(def, "tensor_name", &tensor_name));
|
||||
string recv_device;
|
||||
TF_CHECK_OK(GetNodeAttr(attrs, "recv_device", &recv_device));
|
||||
text = strings::StrCat(memory, node->name(), " = ", node->type_string(),
|
||||
"(", tensor_name, " @", recv_device);
|
||||
TF_CHECK_OK(GetNodeAttr(def, "recv_device", &recv_device));
|
||||
text = strings::StrCat(memory, def.name(), " = ", def.op(), "(",
|
||||
tensor_name, " @", recv_device);
|
||||
is_transfer_node = true;
|
||||
} else if (IsRecv(node)) {
|
||||
string tensor_name;
|
||||
TF_CHECK_OK(GetNodeAttr(attrs, "tensor_name", &tensor_name));
|
||||
TF_CHECK_OK(GetNodeAttr(def, "tensor_name", &tensor_name));
|
||||
string send_device;
|
||||
TF_CHECK_OK(GetNodeAttr(attrs, "send_device", &send_device));
|
||||
text = strings::StrCat(memory, node->name(), " = ", node->type_string(),
|
||||
"(", tensor_name, " @", send_device);
|
||||
TF_CHECK_OK(GetNodeAttr(def, "send_device", &send_device));
|
||||
text = strings::StrCat(memory, def.name(), " = ", def.op(), "(",
|
||||
tensor_name, " @", send_device);
|
||||
is_transfer_node = true;
|
||||
} else {
|
||||
text =
|
||||
strings::StrCat(memory, node->name(), " = ", node->type_string(), "(",
|
||||
str_util::Join(node->requested_inputs(), ", "), ")");
|
||||
text = strings::StrCat(
|
||||
memory, def.name(), " = ", def.op(), "(",
|
||||
str_util::Join(
|
||||
std::vector<StringPiece>(def.input().begin(), def.input().end()),
|
||||
", "),
|
||||
")");
|
||||
}
|
||||
node_stats->set_timeline_label(text);
|
||||
return is_transfer_node;
|
||||
@ -519,7 +522,7 @@ char* GraphView::InitializeNode(char* ptr, const Node* n) {
|
||||
EdgeInfo* dst_edge = item->output_edge_base();
|
||||
for (auto e : n->out_edges()) {
|
||||
dst_edge->dst_id = e->dst()->id();
|
||||
CHECK_LE(e->src_output(), 0x3FFFFFFF); // Must fit in 31 bits
|
||||
CHECK_LE(e->src_output(), ((int32)0x3FFFFFFF)); // Must fit in 31 bits
|
||||
dst_edge->output_slot = e->src_output();
|
||||
dst_edge->is_last = false;
|
||||
const int output_slot = dst_edge->output_slot;
|
||||
@ -637,7 +640,7 @@ Status ExecutorImpl::Initialize() {
|
||||
Status s = params_.create_kernel(n->def(), &item->kernel);
|
||||
if (!s.ok()) {
|
||||
item->kernel = nullptr;
|
||||
s = AttachDef(s, *n);
|
||||
s = AttachDef(s, n->def());
|
||||
LOG(ERROR) << "Executor failed to create kernel. " << s;
|
||||
return s;
|
||||
}
|
||||
@ -665,7 +668,7 @@ Status ExecutorImpl::Initialize() {
|
||||
frame_info->nodes->push_back(n);
|
||||
if (IsEnter(n)) {
|
||||
string enter_name;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "frame_name", &enter_name));
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "frame_name", &enter_name));
|
||||
EnsureFrameInfo(enter_name)->input_count++;
|
||||
}
|
||||
}
|
||||
@ -720,7 +723,7 @@ Status InferAllocAttr(const Node* n, const Node* dst,
|
||||
// so these two cases are not mutually exclusive.
|
||||
if (IsRecv(n)) {
|
||||
string src_name;
|
||||
s = GetNodeAttr(n->attrs(), "send_device", &src_name);
|
||||
s = GetNodeAttr(n->def(), "send_device", &src_name);
|
||||
if (!s.ok()) return s;
|
||||
DeviceNameUtils::ParsedName parsed_src_name;
|
||||
if (!DeviceNameUtils::ParseFullName(src_name, &parsed_src_name)) {
|
||||
@ -745,7 +748,7 @@ Status InferAllocAttr(const Node* n, const Node* dst,
|
||||
}
|
||||
if (IsSend(dst)) {
|
||||
string dst_name;
|
||||
s = GetNodeAttr(dst->attrs(), "recv_device", &dst_name);
|
||||
s = GetNodeAttr(dst->def(), "recv_device", &dst_name);
|
||||
if (!s.ok()) return s;
|
||||
DeviceNameUtils::ParsedName parsed_dst_name;
|
||||
if (!DeviceNameUtils::ParseFullName(dst_name, &parsed_dst_name)) {
|
||||
@ -1358,7 +1361,7 @@ Status ExecutorImpl::BuildControlFlowInfo(const Graph* g,
|
||||
if (IsEnter(curr_node)) {
|
||||
// Enter a child frame.
|
||||
TF_RETURN_IF_ERROR(
|
||||
GetNodeAttr(curr_node->attrs(), "frame_name", &frame_name));
|
||||
GetNodeAttr(curr_node->def(), "frame_name", &frame_name));
|
||||
parent = curr_node;
|
||||
} else if (IsExit(curr_node)) {
|
||||
// Exit to the parent frame.
|
||||
@ -1552,7 +1555,8 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_usec) {
|
||||
|
||||
if (vlog_) {
|
||||
VLOG(1) << "Process node: " << id << " step " << params.step_id << " "
|
||||
<< SummarizeNode(*node) << " is dead: " << tagged_node.is_dead;
|
||||
<< SummarizeNodeDef(node->def())
|
||||
<< " is dead: " << tagged_node.is_dead;
|
||||
}
|
||||
|
||||
Entry* input_tensors = GetInputTensors(input_frame, input_iter);
|
||||
@ -1606,7 +1610,7 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_usec) {
|
||||
|
||||
if (vlog_) {
|
||||
VLOG(2) << this << " Async kernel done: "
|
||||
<< SummarizeNode(*state->item->node);
|
||||
<< SummarizeNodeDef(state->item->node->def());
|
||||
}
|
||||
if (stats) nodestats::SetOpEnd(stats);
|
||||
EntryVector outputs;
|
||||
@ -1807,7 +1811,7 @@ Status ExecutorState::ProcessOutputs(const NodeItem& item, OpKernelContext* ctx,
|
||||
// tensor value at i-th output.
|
||||
if (!IsSwitch(node) && !IsRecv(node)) {
|
||||
s.Update(errors::Internal("Missing ", i, "-th output from ",
|
||||
SummarizeNode(*node)));
|
||||
SummarizeNodeDef(node->def())));
|
||||
}
|
||||
} else {
|
||||
Entry* out = &((*outputs)[i]);
|
||||
@ -1874,7 +1878,7 @@ Status ExecutorState::ProcessOutputs(const NodeItem& item, OpKernelContext* ctx,
|
||||
DataTypeString(dtype),
|
||||
" does not match declared output type ",
|
||||
DataTypeString(item.output_type(i)),
|
||||
" for node ", SummarizeNode(*node)));
|
||||
" for node ", SummarizeNodeDef(node->def())));
|
||||
}
|
||||
}
|
||||
if (!val.is_ref()) {
|
||||
@ -1911,7 +1915,7 @@ void ExecutorState::PropagateOutputs(const TaggedNode& tagged_node,
|
||||
&impl_->gview_, input_iter, ready);
|
||||
} else if (item->is_enter) {
|
||||
bool is_constant;
|
||||
Status s = GetNodeAttr(node->attrs(), "is_constant", &is_constant);
|
||||
Status s = GetNodeAttr(node->def(), "is_constant", &is_constant);
|
||||
DCHECK(s.ok()) << s;
|
||||
FindOrCreateChildFrame(input_frame, input_iter, node, &output_frame);
|
||||
output_iter = 0;
|
||||
@ -2237,7 +2241,7 @@ void ExecutorState::FindOrCreateChildFrame(FrameState* frame, int64 iter,
|
||||
FrameState** child) {
|
||||
// Get the child frame name.
|
||||
string enter_name;
|
||||
Status s = GetNodeAttr(node->attrs(), "frame_name", &enter_name);
|
||||
Status s = GetNodeAttr(node->def(), "frame_name", &enter_name);
|
||||
DCHECK(s.ok()) << s;
|
||||
const string child_name = MakeFrameName(frame, iter, enter_name);
|
||||
|
||||
@ -2255,7 +2259,7 @@ void ExecutorState::FindOrCreateChildFrame(FrameState* frame, int64 iter,
|
||||
if (vlog_) VLOG(2) << "Create frame: " << child_name;
|
||||
|
||||
int parallel_iters;
|
||||
s = GetNodeAttr(node->attrs(), "parallel_iterations", ¶llel_iters);
|
||||
s = GetNodeAttr(node->def(), "parallel_iterations", ¶llel_iters);
|
||||
DCHECK(s.ok()) << s;
|
||||
FrameState* temp = new FrameState(impl_, parallel_iters);
|
||||
temp->frame_name = child_name;
|
||||
|
@ -150,7 +150,8 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
|
||||
|
||||
~FunctionLibraryRuntimeImpl() override;
|
||||
|
||||
Status Instantiate(const string& function_name, AttrSlice attrs,
|
||||
Status Instantiate(const string& function_name,
|
||||
const InstantiateAttrValueMap& attrs,
|
||||
Handle* handle) override;
|
||||
|
||||
const FunctionBody* GetFunctionBody(Handle handle) override;
|
||||
@ -207,7 +208,8 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
|
||||
};
|
||||
std::vector<Item*> items_;
|
||||
|
||||
Status FunctionDefToBody(const FunctionDef& fdef, AttrSlice attrs,
|
||||
Status FunctionDefToBody(const FunctionDef& fdef,
|
||||
const InstantiateAttrValueMap& attrs,
|
||||
FunctionBody** fbody);
|
||||
Status CreateItem(Handle handle, Item** item);
|
||||
Status GetOrCreateItem(Handle handle, Item** item);
|
||||
@ -322,7 +324,7 @@ Status FunctionLibraryRuntimeImpl::CreateKernel(const NodeDef& ndef,
|
||||
// Try to instantiate this function for the func/attr. Maybe its
|
||||
// cached already.
|
||||
Handle handle;
|
||||
TF_RETURN_IF_ERROR(Instantiate(ndef.op(), AttrSlice(&ndef.attr()), &handle));
|
||||
TF_RETURN_IF_ERROR(Instantiate(ndef.op(), ndef.attr(), &handle));
|
||||
|
||||
const FunctionBody* fbody = GetFunctionBody(handle);
|
||||
CHECK_NOTNULL(fbody);
|
||||
@ -353,9 +355,9 @@ Status FunctionLibraryRuntimeImpl::CreateKernel(const NodeDef& ndef,
|
||||
return s;
|
||||
}
|
||||
|
||||
Status FunctionLibraryRuntimeImpl::FunctionDefToBody(const FunctionDef& fdef,
|
||||
AttrSlice attrs,
|
||||
FunctionBody** fbody) {
|
||||
Status FunctionLibraryRuntimeImpl::FunctionDefToBody(
|
||||
const FunctionDef& fdef, const InstantiateAttrValueMap& attrs,
|
||||
FunctionBody** fbody) {
|
||||
// Instantiates the function template into a graph def.
|
||||
InstantiationResult result;
|
||||
TF_RETURN_IF_ERROR(InstantiateFunction(fdef, attrs, get_func_sig_, &result));
|
||||
@ -388,13 +390,11 @@ Status FunctionLibraryRuntimeImpl::InstantiateSymbolicGradient(
|
||||
// TODO(josh11b): Should filter out the attrs from func that aren't used
|
||||
// by the gradient function.
|
||||
TF_RETURN_IF_ERROR(creator(AttrSlice(&func.attr()), &grad_fdef));
|
||||
TF_RETURN_IF_ERROR(
|
||||
FunctionDefToBody(grad_fdef, AttrSlice(&func.attr()), g_body));
|
||||
TF_RETURN_IF_ERROR(FunctionDefToBody(grad_fdef, func.attr(), g_body));
|
||||
} else {
|
||||
// f is a user-defined function.
|
||||
Handle f_handle;
|
||||
TF_RETURN_IF_ERROR(
|
||||
Instantiate(func.name(), AttrSlice(&func.attr()), &f_handle));
|
||||
TF_RETURN_IF_ERROR(Instantiate(func.name(), func.attr(), &f_handle));
|
||||
const FunctionBody* f_body = GetFunctionBody(f_handle);
|
||||
CHECK_NOTNULL(f_body);
|
||||
*g_body = SymbolicGradient(*f_body);
|
||||
@ -402,9 +402,9 @@ Status FunctionLibraryRuntimeImpl::InstantiateSymbolicGradient(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status FunctionLibraryRuntimeImpl::Instantiate(const string& function_name,
|
||||
AttrSlice attrs,
|
||||
Handle* handle) {
|
||||
Status FunctionLibraryRuntimeImpl::Instantiate(
|
||||
const string& function_name, const InstantiateAttrValueMap& attrs,
|
||||
Handle* handle) {
|
||||
const string key = Canonicalize(function_name, attrs);
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
@ -417,7 +417,7 @@ Status FunctionLibraryRuntimeImpl::Instantiate(const string& function_name,
|
||||
Status s;
|
||||
FunctionBody* fbody = nullptr;
|
||||
if (function_name == kGradientOp) {
|
||||
const AttrValue* f = attrs.Find(kFuncAttr);
|
||||
const AttrValue* f = gtl::FindOrNull(attrs, kFuncAttr);
|
||||
if (f == nullptr) {
|
||||
return errors::InvalidArgument("SymbolicGradient is missing attr: f");
|
||||
}
|
||||
@ -427,7 +427,7 @@ Status FunctionLibraryRuntimeImpl::Instantiate(const string& function_name,
|
||||
}
|
||||
const string grad = lib_def_->FindGradient(func.name());
|
||||
if (!grad.empty()) {
|
||||
return Instantiate(grad, AttrSlice(&func.attr()), handle);
|
||||
return Instantiate(grad, func.attr(), handle);
|
||||
}
|
||||
TF_RETURN_IF_ERROR(InstantiateSymbolicGradient(func, &fbody));
|
||||
} else {
|
||||
@ -989,12 +989,13 @@ bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph) {
|
||||
for (Node* node : graph->nodes()) {
|
||||
VLOG(3) << "Expanding " << node->DebugString();
|
||||
bool noinline;
|
||||
if (fld->GetAttr(*node, kNoInlineAttr, &noinline).ok() && noinline) {
|
||||
if (fld->GetAttr(node->def(), kNoInlineAttr, &noinline).ok() && noinline) {
|
||||
VLOG(3) << "noinline: " << node->DebugString();
|
||||
continue;
|
||||
}
|
||||
FunctionLibraryRuntime::Handle handle;
|
||||
Status s = lib->Instantiate(node->type_string(), node->attrs(), &handle);
|
||||
Status s =
|
||||
lib->Instantiate(node->type_string(), node->def().attr(), &handle);
|
||||
if (!s.ok()) {
|
||||
// Either "node" is a primitive op, or the instantiation failed.
|
||||
if (errors::IsNotFound(s)) {
|
||||
@ -1102,7 +1103,7 @@ FunctionBody::FunctionBody(const FunctionDef& f, DataTypeSlice arg_t,
|
||||
continue;
|
||||
}
|
||||
int index;
|
||||
TF_CHECK_OK(GetNodeAttr(n->attrs(), "index", &index));
|
||||
TF_CHECK_OK(GetNodeAttr(n->def(), "index", &index));
|
||||
CHECK_LE(0, index);
|
||||
CHECK_LT(index, node_vec->size());
|
||||
(*node_vec)[index] = n;
|
||||
|
@ -40,7 +40,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/util/equal_graph_def.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
typedef FunctionDefHelper FDH;
|
||||
|
||||
@ -59,29 +58,13 @@ void HasError(const Status& s, const string& substr) {
|
||||
<< s << ", expected substring " << substr;
|
||||
}
|
||||
|
||||
// A helper class to make AttrSlice from initializer lists
|
||||
class Attrs {
|
||||
public:
|
||||
Attrs(const std::initializer_list< // NOLINT(runtime/explicit)
|
||||
std::pair<string, FunctionDefHelper::AttrValueWrapper>>& attrs) {
|
||||
for (const auto& aval : attrs) {
|
||||
map_.insert({aval.first, aval.second.proto});
|
||||
}
|
||||
}
|
||||
|
||||
operator AttrSlice() { return AttrSlice(&map_); } // NOLINT(runtime/explicit)
|
||||
|
||||
private:
|
||||
AttrValueMap map_;
|
||||
};
|
||||
|
||||
class FunctionTest : public ::testing::Test {
|
||||
protected:
|
||||
FunctionTest()
|
||||
: device_(DeviceFactory::NewDevice("CPU", {},
|
||||
"/job:localhost/replica:0/task:0")) {}
|
||||
|
||||
void Create(const FunctionDef& fdef, Attrs attrs) {
|
||||
void Create(const FunctionDef& fdef, InstantiateAttrValueSlice attrs) {
|
||||
exec_ = nullptr;
|
||||
InstantiationResult result;
|
||||
TF_CHECK_OK(InstantiateFunction(fdef, attrs, GetOpSig, &result));
|
||||
@ -168,8 +151,8 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
|
||||
fdef_lib_ = lib_def_->ToProto();
|
||||
}
|
||||
|
||||
Status Run(const string& name, Attrs attrs, const std::vector<Tensor>& args,
|
||||
std::vector<Tensor*> rets) {
|
||||
Status Run(const string& name, InstantiateAttrValueSlice attrs,
|
||||
const std::vector<Tensor>& args, std::vector<Tensor*> rets) {
|
||||
FunctionLibraryRuntime::Handle handle;
|
||||
Status status = lib_->Instantiate(name, attrs, &handle);
|
||||
if (!status.ok()) {
|
||||
@ -205,7 +188,8 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::unique_ptr<Graph> GetFuncBody(const string& name, Attrs attrs) {
|
||||
std::unique_ptr<Graph> GetFuncBody(const string& name,
|
||||
InstantiateAttrValueSlice attrs) {
|
||||
FunctionLibraryRuntime::Handle handle;
|
||||
Status status = lib_->Instantiate(name, attrs, &handle);
|
||||
if (!status.ok()) {
|
||||
@ -219,7 +203,8 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
|
||||
return ret;
|
||||
}
|
||||
|
||||
std::unique_ptr<Graph> GetGradBody(const string& func, Attrs attrs) {
|
||||
std::unique_ptr<Graph> GetGradBody(const string& func,
|
||||
InstantiateAttrValueSlice attrs) {
|
||||
FunctionLibraryRuntime::Handle handle;
|
||||
Status status = lib_->Instantiate(func, attrs, &handle);
|
||||
if (!status.ok()) {
|
||||
@ -630,14 +615,13 @@ TEST_F(FunctionLibraryRuntimeTest, Error_InstantiaionError) {
|
||||
|
||||
// Instantiating "XTimesTwo" should fail.
|
||||
FunctionLibraryRuntime::Handle handle;
|
||||
HasError(lib_->Instantiate("XTimesTwo", Attrs({{"T", DT_FLOAT}}), &handle),
|
||||
HasError(lib_->Instantiate("XTimesTwo", {{"T", DT_FLOAT}}, &handle),
|
||||
"Not found: type attr not found");
|
||||
|
||||
// But XTimesFour and XTimes16 instantiation should succeed. Only
|
||||
// when they run, they fail because XTimesTwo is bad.
|
||||
TF_CHECK_OK(
|
||||
lib_->Instantiate("XTimesFour", Attrs({{"T", DT_FLOAT}}), &handle));
|
||||
TF_CHECK_OK(lib_->Instantiate("XTimes16", Attrs({{"T", DT_FLOAT}}), &handle));
|
||||
TF_CHECK_OK(lib_->Instantiate("XTimesFour", {{"T", DT_FLOAT}}, &handle));
|
||||
TF_CHECK_OK(lib_->Instantiate("XTimes16", {{"T", DT_FLOAT}}, &handle));
|
||||
|
||||
auto x = test::AsTensor<float>({1, 2, 3, 4});
|
||||
Tensor y;
|
||||
@ -944,7 +928,8 @@ bool DoNothing(Graph* g) { return false; }
|
||||
GraphDef Optimize(const std::function<bool(Graph* g)>& pass,
|
||||
const FunctionDef& fdef) {
|
||||
InstantiationResult result;
|
||||
TF_CHECK_OK(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result));
|
||||
InstantiateAttrValueMap empty;
|
||||
TF_CHECK_OK(InstantiateFunction(fdef, empty, GetOpSig, &result));
|
||||
std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
|
||||
GraphConstructorOptions opts;
|
||||
opts.allow_internal_ops = true;
|
||||
@ -1263,5 +1248,4 @@ TEST(OptimizationTest, RemoveListArrayConverter_WithContolDeps) {
|
||||
TF_EXPECT_GRAPH_EQ(expected, Optimize(remove_listarray_and_identity, func));
|
||||
}
|
||||
|
||||
} // end namespace
|
||||
} // end namespace tensorflow
|
||||
|
@ -103,13 +103,13 @@ void Benchmark::Run(int iters) { RunWithArgs({}, {}, iters); }
|
||||
|
||||
string GetRendezvousKey(const Node* node) {
|
||||
string send_device;
|
||||
TF_CHECK_OK(GetNodeAttr(node->attrs(), "send_device", &send_device));
|
||||
TF_CHECK_OK(GetNodeAttr(node->def(), "send_device", &send_device));
|
||||
string recv_device;
|
||||
TF_CHECK_OK(GetNodeAttr(node->attrs(), "recv_device", &recv_device));
|
||||
TF_CHECK_OK(GetNodeAttr(node->def(), "recv_device", &recv_device));
|
||||
string tensor_name;
|
||||
TF_CHECK_OK(GetNodeAttr(node->attrs(), "tensor_name", &tensor_name));
|
||||
TF_CHECK_OK(GetNodeAttr(node->def(), "tensor_name", &tensor_name));
|
||||
uint64 send_device_incarnation;
|
||||
TF_CHECK_OK(GetNodeAttr(node->attrs(), "send_device_incarnation",
|
||||
TF_CHECK_OK(GetNodeAttr(node->def(), "send_device_incarnation",
|
||||
reinterpret_cast<int64*>(&send_device_incarnation)));
|
||||
return Rendezvous::CreateKey(send_device, send_device_incarnation,
|
||||
recv_device, tensor_name, FrameAndIter(0, 0));
|
||||
|
@ -49,11 +49,11 @@ class ParallelConcatRemovePass : public GraphOptimizationPass {
|
||||
}
|
||||
}
|
||||
for (Node* n : matches) {
|
||||
AttrSlice n_attrs = n->attrs();
|
||||
AttrSlice n_attrs(n->def());
|
||||
auto base_make_node = [n, g, &n_attrs](const string& op,
|
||||
const string& name) {
|
||||
NodeBuilder node_builder(name, op);
|
||||
node_builder.Device(n->requested_device());
|
||||
node_builder.Device(n->def().device());
|
||||
string colo;
|
||||
if (GetNodeAttr(n_attrs, "_class", &colo).ok()) {
|
||||
node_builder.Attr("_class", colo);
|
||||
|
@ -55,7 +55,7 @@ class ResourceVariableReadPass : public GraphOptimizationPass {
|
||||
}
|
||||
for (Node* read : matches) {
|
||||
DataType dtype;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(read->attrs(), "dtype", &dtype));
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(AttrSlice(read->def()), "dtype", &dtype));
|
||||
std::vector<Node*> in_control_edges;
|
||||
std::vector<std::pair<Node*, int>> in_edges;
|
||||
for (const Edge* edge : read->in_edges()) {
|
||||
|
@ -76,7 +76,7 @@ void ColocationGroups(const Node& node,
|
||||
std::vector<string> class_specs;
|
||||
// TODO(vrv): We should consider adding a GetNodeAttr that returns a
|
||||
// StringPiece, to avoid a copy.
|
||||
if (!GetNodeAttrSimple(node.attrs(), kColocationAttrNameStringPiece,
|
||||
if (!GetNodeAttrSimple(node.def(), kColocationAttrNameStringPiece,
|
||||
&class_specs)) {
|
||||
// No attribute value is equivalent to the empty colocation_group.
|
||||
*colocation_groups = {
|
||||
@ -329,7 +329,7 @@ class ColocationGraph {
|
||||
AddDebugInfo(node_root, &debug_info);
|
||||
|
||||
DeviceNameUtils::ParsedName specified_device_name;
|
||||
if (DeviceNameUtils::ParseFullName(node->requested_device(),
|
||||
if (DeviceNameUtils::ParseFullName(node->def().device(),
|
||||
&specified_device_name) &&
|
||||
specified_device_name == members_[node_root].device_name) {
|
||||
// The specified device and merged set device match, and
|
||||
@ -348,27 +348,27 @@ class ColocationGraph {
|
||||
std::sort(device_names.begin(), device_names.end());
|
||||
|
||||
return errors::InvalidArgument(
|
||||
"Operation was explicitly assigned to ",
|
||||
node->requested_device(), " but available devices are [ ",
|
||||
"Operation was explicitly assigned to ", node->def().device(),
|
||||
" but available devices are [ ",
|
||||
str_util::Join(device_names, ", "), " ]. Make sure ",
|
||||
"the device specification refers to a valid device.");
|
||||
} else if (specified_device_name.has_type) {
|
||||
return errors::InvalidArgument(
|
||||
"Could not satisfy explicit device specification '",
|
||||
node->requested_device(), "' because no supported kernel for ",
|
||||
node->def().device(), "' because no supported kernel for ",
|
||||
specified_device_name.type, " devices is available.",
|
||||
debug_info);
|
||||
} else {
|
||||
return errors::InvalidArgument(
|
||||
"Could not satisfy explicit device specification '",
|
||||
node->requested_device(), debug_info);
|
||||
node->def().device(), debug_info);
|
||||
}
|
||||
} else {
|
||||
// The specified device may be a valid device but the
|
||||
// merged set device is different, so print both.
|
||||
return errors::InvalidArgument(
|
||||
"Could not satisfy explicit device specification '",
|
||||
node->requested_device(),
|
||||
node->def().device(),
|
||||
"' because the node was colocated with a group of nodes that "
|
||||
"required incompatible device '",
|
||||
DeviceNameUtils::ParsedNameToString(
|
||||
@ -513,7 +513,7 @@ class ColocationGraph {
|
||||
return errors::Internal("Assigned device '", node.assigned_device_name(),
|
||||
"' does not have registered OpKernel support "
|
||||
"for ",
|
||||
node.type_string());
|
||||
node.def().op());
|
||||
} else {
|
||||
// This node has not yet been assigned to a device, so we
|
||||
// calculate any constraints due to the set of registered
|
||||
@ -527,25 +527,25 @@ class ColocationGraph {
|
||||
registered_device_types.insert(d->device_type());
|
||||
}
|
||||
return errors::InvalidArgument(
|
||||
"No OpKernel was registered to support Op '", node.type_string(),
|
||||
"No OpKernel was registered to support Op '", node.def().op(),
|
||||
"' with these attrs. Registered devices: [",
|
||||
str_util::Join(registered_device_types, ","),
|
||||
"], Registered kernels:\n",
|
||||
KernelsRegisteredForOp(node.type_string()));
|
||||
KernelsRegisteredForOp(node.def().op()));
|
||||
}
|
||||
|
||||
// If the NodeDef contains a device, then we interpret it as a
|
||||
// (partial) device specification.
|
||||
if (!node.requested_device().empty()) {
|
||||
if (!node.def().device().empty()) {
|
||||
// The user has specified a device in the NodeDef, try to find a
|
||||
// valid device matching their specification in the set of
|
||||
// devices.
|
||||
// NOTE: The full name may specify a device that is not in
|
||||
// n.supported_device_types(), but we check that in AssignDevice().
|
||||
if (!DeviceNameUtils::ParseFullName(node.requested_device(),
|
||||
if (!DeviceNameUtils::ParseFullName(node.def().device(),
|
||||
&member->device_name)) {
|
||||
return errors::InvalidArgument("Malformed device specification '",
|
||||
node.requested_device(), "'");
|
||||
node.def().device(), "'");
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -644,7 +644,7 @@ Status SimplePlacer::Run() {
|
||||
continue;
|
||||
}
|
||||
status = colocation_graph.AddNode(*node);
|
||||
if (!status.ok()) return AttachDef(status, *node);
|
||||
if (!status.ok()) return AttachDef(status, node->def());
|
||||
}
|
||||
|
||||
// 2. Enumerate the constraint edges, and use them to update the disjoint
|
||||
@ -707,7 +707,7 @@ Status SimplePlacer::Run() {
|
||||
"be on the same device), but the two nodes "
|
||||
"were assigned two different devices: ",
|
||||
status.error_message()),
|
||||
*node);
|
||||
node->def());
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -749,7 +749,7 @@ Status SimplePlacer::Run() {
|
||||
return AttachDef(
|
||||
errors::InvalidArgument("Cannot assign a device for operation '",
|
||||
node->name(), "': ", status.error_message()),
|
||||
*node);
|
||||
node->def());
|
||||
}
|
||||
|
||||
// Returns the first device in sorted devices list so we will always
|
||||
@ -791,7 +791,7 @@ Status SimplePlacer::Run() {
|
||||
return AttachDef(
|
||||
errors::InvalidArgument("Cannot assign a device for operation '",
|
||||
node->name(), "': ", status.error_message()),
|
||||
*node);
|
||||
node->def());
|
||||
}
|
||||
|
||||
string assigned_device = devices[0]->name();
|
||||
|
@ -223,16 +223,19 @@ Status DebugNodeInserter::InsertNodes(
|
||||
void DebugNodeInserter::DeparallelizeWhileLoops(Graph* graph, Device* device) {
|
||||
for (Node* node : graph->nodes()) {
|
||||
if (node->IsEnter()) {
|
||||
const AttrValue* parallel_iterations =
|
||||
node->attrs().Find("parallel_iterations");
|
||||
if (parallel_iterations && parallel_iterations->i() > 1) {
|
||||
LOG(INFO) << "For debugging, tfdbg is changing the "
|
||||
<< "parallel_iterations attribute of the Enter/RefEnter "
|
||||
<< "node \"" << node->name() << "\" on device \""
|
||||
<< device->name() << "\" from " << parallel_iterations->i()
|
||||
<< " to 1. (This does not affect subsequent non-debug "
|
||||
<< "runs.)";
|
||||
node->AddAttr<int64>("parallel_iterations", 1);
|
||||
for (const auto& attr : node->def().attr()) {
|
||||
if (attr.first == "parallel_iterations") {
|
||||
if (attr.second.i() > 1) {
|
||||
LOG(INFO) << "For debugging, tfdbg is changing the "
|
||||
<< "parallel_iterations attribute of the Enter/RefEnter "
|
||||
<< "node \"" << node->name() << "\" on device \""
|
||||
<< device->name() << "\" from " << attr.second.i()
|
||||
<< " to 1. (This does not affect subsequent non-debug "
|
||||
<< "runs.)";
|
||||
node->AddAttr<int64>("parallel_iterations", 1);
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -188,7 +188,7 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
|
||||
const RunState* run_state,
|
||||
SimpleGraphExecutionState* execution_state);
|
||||
|
||||
string DetailText(const Node& node, const NodeExecStats& ns) {
|
||||
string DetailText(const NodeDef& def, const NodeExecStats& ns) {
|
||||
int64 tot = 0;
|
||||
for (auto& no : ns.output()) {
|
||||
tot += no.tensor_description().allocation_description().requested_bytes();
|
||||
@ -197,8 +197,12 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
|
||||
if (tot >= 0.1 * 1048576.0) {
|
||||
bytes = strings::Printf("[%.1fMB] ", tot / 1048576.0);
|
||||
}
|
||||
return strings::StrCat(bytes, node.name(), " = ", node.type_string(), "(",
|
||||
str_util::Join(node.requested_inputs(), ", "), ")");
|
||||
return strings::StrCat(
|
||||
bytes, def.name(), " = ", def.op(), "(",
|
||||
str_util::Join(
|
||||
std::vector<StringPiece>(def.input().begin(), def.input().end()),
|
||||
", "),
|
||||
")");
|
||||
}
|
||||
|
||||
private:
|
||||
@ -786,7 +790,7 @@ void MasterSession::ReffedClientGraph::ProcessDeviceStats(
|
||||
if (!ns.timeline_label().empty()) {
|
||||
details = ns.timeline_label();
|
||||
} else if (found_node_in_graph) {
|
||||
details = DetailText(*node, ns);
|
||||
details = DetailText(node->def(), ns);
|
||||
} else {
|
||||
// Leave details string empty
|
||||
}
|
||||
|
@ -22,7 +22,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/function.pb_text.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||
@ -45,11 +44,12 @@ namespace {
|
||||
// Otherwise (arg_def is a simple type T), *is_type_list is set to
|
||||
// false, and *dtypes is set to a single element vector, whose only
|
||||
// element is T.
|
||||
Status ArgNumType(AttrSlice attrs, const OpDef::ArgDef& arg_def,
|
||||
bool* is_type_list, DataTypeVector* dtypes) {
|
||||
Status ArgNumType(const InstantiateAttrValueMap& attrs,
|
||||
const OpDef::ArgDef& arg_def, bool* is_type_list,
|
||||
DataTypeVector* dtypes) {
|
||||
dtypes->clear();
|
||||
if (!arg_def.type_list_attr().empty()) {
|
||||
const AttrValue* v = attrs.Find(arg_def.type_list_attr());
|
||||
const AttrValue* v = gtl::FindOrNull(attrs, arg_def.type_list_attr());
|
||||
if (v == nullptr) {
|
||||
return errors::NotFound("type attr not found: ",
|
||||
arg_def.type_list_attr());
|
||||
@ -64,7 +64,7 @@ Status ArgNumType(AttrSlice attrs, const OpDef::ArgDef& arg_def,
|
||||
*is_type_list = false;
|
||||
int num = 1;
|
||||
if (!arg_def.number_attr().empty()) {
|
||||
const AttrValue* v = attrs.Find(arg_def.number_attr());
|
||||
const AttrValue* v = gtl::FindOrNull(attrs, arg_def.number_attr());
|
||||
if (v == nullptr) {
|
||||
return errors::NotFound("type attr not found: ", arg_def.type_attr());
|
||||
}
|
||||
@ -77,7 +77,7 @@ Status ArgNumType(AttrSlice attrs, const OpDef::ArgDef& arg_def,
|
||||
} else if (arg_def.type_attr().empty()) {
|
||||
dtype = DT_INVALID;
|
||||
} else {
|
||||
const AttrValue* v = attrs.Find(arg_def.type_attr());
|
||||
const AttrValue* v = gtl::FindOrNull(attrs, arg_def.type_attr());
|
||||
if (v == nullptr) {
|
||||
return errors::NotFound("type attr not found: ", arg_def.type_attr());
|
||||
}
|
||||
@ -92,17 +92,18 @@ void AddAttr(const string& name, const T& val, NodeDef* ndef) {
|
||||
SetAttrValue(val, &((*ndef->mutable_attr())[name]));
|
||||
}
|
||||
|
||||
Status ValidateSignatureWithAttrs(const OpDef& sig, AttrSlice attr_values) {
|
||||
Status ValidateSignatureWithAttrs(const OpDef& sig,
|
||||
const InstantiateAttrValueMap& attr_values) {
|
||||
// attr_values should specify all attrs defined in fdef.
|
||||
for (const auto& a : sig.attr()) {
|
||||
const AttrValue* v = attr_values.Find(a.name());
|
||||
if (!v) {
|
||||
auto const iter = attr_values.find(a.name());
|
||||
if (iter == attr_values.end()) {
|
||||
return errors::NotFound("Attr ", a.name(), " is not found from ",
|
||||
SummarizeOpDef(sig));
|
||||
}
|
||||
Status status = AttrValueHasType(*v, a.type());
|
||||
Status status = AttrValueHasType(iter->second, a.type());
|
||||
if (!status.ok()) {
|
||||
errors::AppendToMessage(&status, "for attr '", a.name(), "'");
|
||||
errors::AppendToMessage(&status, "for attr '", iter->first, "'");
|
||||
return status;
|
||||
}
|
||||
}
|
||||
@ -145,7 +146,7 @@ class FunctionInstantiationHelper {
|
||||
|
||||
// Builds index for nodes that can be used as node's input arguments.
|
||||
Status BuildInputArgIndex(const OpDef::ArgDef& arg_def,
|
||||
AttrSlice attr_values) {
|
||||
const InstantiateAttrValueMap& attr_values) {
|
||||
bool is_type_list;
|
||||
DataTypeVector dtypes;
|
||||
TF_RETURN_IF_ERROR(
|
||||
@ -174,7 +175,8 @@ class FunctionInstantiationHelper {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status BuildNodeOutputIndex(const NodeDef& node, AttrSlice attrs,
|
||||
Status BuildNodeOutputIndex(const NodeDef& node,
|
||||
const InstantiateAttrValueMap& attrs,
|
||||
const int arg_index) {
|
||||
const OpDef* node_sig = nullptr;
|
||||
TF_RETURN_IF_ERROR(get_function_(node.op(), &node_sig));
|
||||
@ -204,7 +206,8 @@ class FunctionInstantiationHelper {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status InstantiateNode(const NodeDef& fnode, AttrSlice attrs) {
|
||||
Status InstantiateNode(const NodeDef& fnode,
|
||||
const InstantiateAttrValueMap& attrs) {
|
||||
const OpDef* fnode_sig = nullptr;
|
||||
TF_CHECK_OK(get_function_(fnode.op(), &fnode_sig));
|
||||
NodeDef* gnode = AddNode(fnode.name());
|
||||
@ -292,7 +295,7 @@ class FunctionInstantiationHelper {
|
||||
}
|
||||
|
||||
Status AddReturnNode(
|
||||
const OpDef::ArgDef& ret_def, AttrSlice attrs,
|
||||
const OpDef::ArgDef& ret_def, const InstantiateAttrValueMap& attrs,
|
||||
const ::tensorflow::protobuf::Map<string, string>& ret_map,
|
||||
int* ret_index) {
|
||||
auto ret_iter = ret_map.find(ret_def.name());
|
||||
@ -601,7 +604,7 @@ string Print(const GraphDef& gdef) {
|
||||
|
||||
Status AddDefaultAttrs(const string& op,
|
||||
const GetFunctionSignature& get_function,
|
||||
AttrValueMap* attrs) {
|
||||
InstantiateAttrValueMap* attrs) {
|
||||
const OpDef* op_def = nullptr;
|
||||
TF_RETURN_IF_ERROR(get_function(op, &op_def));
|
||||
AttrSlice attr_slice(attrs);
|
||||
@ -617,7 +620,8 @@ Status AddDefaultAttrs(const string& op,
|
||||
|
||||
} // end namespace
|
||||
|
||||
Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values,
|
||||
Status InstantiateFunction(const FunctionDef& fdef,
|
||||
const InstantiateAttrValueMap& attr_values,
|
||||
GetFunctionSignature get_function,
|
||||
InstantiationResult* result) {
|
||||
VLOG(3) << "Instantiation Function: " << Print(fdef);
|
||||
@ -635,17 +639,19 @@ Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values,
|
||||
}
|
||||
}
|
||||
|
||||
auto substitute = [attr_values](StringPiece name, AttrValue* val) {
|
||||
if (const AttrValue* v = attr_values.Find(name)) {
|
||||
*val = *v;
|
||||
auto substitute = [&attr_values](const string& name, AttrValue* val) {
|
||||
auto iter = attr_values.find(name);
|
||||
if (iter == attr_values.end()) {
|
||||
return false;
|
||||
} else {
|
||||
*val = iter->second;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
// Makes a copy of all attrs in fdef and substitutes placeholders.
|
||||
// After this step, every attr is bound to a concrete value.
|
||||
std::vector<AttrValueMap> node_attrs;
|
||||
std::vector<InstantiateAttrValueMap> node_attrs;
|
||||
node_attrs.resize(fdef.node_def_size());
|
||||
for (int i = 0; i < fdef.node_def_size(); ++i) {
|
||||
for (auto attr : fdef.node_def(i).attr()) {
|
||||
@ -662,7 +668,7 @@ Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values,
|
||||
}
|
||||
|
||||
for (int i = 0; i < fdef.node_def_size(); ++i) {
|
||||
s = helper.BuildNodeOutputIndex(fdef.node_def(i), AttrSlice(&node_attrs[i]),
|
||||
s = helper.BuildNodeOutputIndex(fdef.node_def(i), node_attrs[i],
|
||||
result->gdef.node_size() + i);
|
||||
if (!s.ok()) {
|
||||
errors::AppendToMessage(&s, "In ", SummarizeNodeDef(fdef.node_def(i)));
|
||||
@ -671,7 +677,7 @@ Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values,
|
||||
}
|
||||
// Emits one gdef.node for each fdef.node_def.
|
||||
for (int i = 0; i < fdef.node_def_size(); ++i) {
|
||||
s = helper.InstantiateNode(fdef.node_def(i), AttrSlice(&node_attrs[i]));
|
||||
s = helper.InstantiateNode(fdef.node_def(i), node_attrs[i]);
|
||||
if (!s.ok()) {
|
||||
errors::AppendToMessage(&s, "In ", SummarizeNodeDef(fdef.node_def(i)));
|
||||
return s;
|
||||
@ -742,7 +748,8 @@ bool FunctionDefsEqual(const FunctionDef& f1, const FunctionDef& f2) {
|
||||
return true;
|
||||
}
|
||||
|
||||
string Canonicalize(const string& funcname, AttrSlice attrs) {
|
||||
string Canonicalize(const string& funcname,
|
||||
const InstantiateAttrValueMap& attrs) {
|
||||
std::vector<string> entries;
|
||||
entries.reserve(attrs.size());
|
||||
for (auto p : attrs) {
|
||||
@ -946,7 +953,8 @@ const FunctionDef* FunctionLibraryDefinition::GetAttrImpl(
|
||||
// If ndef is SymbolicGradient[f=Foo], we use Foo's gradient or
|
||||
// Foo's attributes.
|
||||
const NameAttrList* forward_func_attrs;
|
||||
if (!GetNodeAttr(ndef, kFuncAttr, &forward_func_attrs).ok()) {
|
||||
if (!GetNodeAttr(AttrSlice(&ndef.attr()), kFuncAttr, &forward_func_attrs)
|
||||
.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
const string& func_name = forward_func_attrs->name();
|
||||
@ -973,30 +981,34 @@ FunctionDefLibrary FunctionLibraryDefinition::ToProto() const {
|
||||
return lib;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status FunctionLibraryDefinition::GetAttr(const NodeDef& ndef,
|
||||
const string& attr, T* value) const {
|
||||
const FunctionDef* fdef = GetAttrImpl(ndef);
|
||||
if (fdef && GetNodeAttr(AttrSlice(&fdef->attr()), attr, value).ok()) {
|
||||
return Status::OK();
|
||||
Status InstantiateFunction(const FunctionDef& fdef,
|
||||
InstantiateAttrValueSlice attr_values,
|
||||
GetFunctionSignature get_function,
|
||||
InstantiationResult* result) {
|
||||
InstantiateAttrValueMap m;
|
||||
for (const auto& aval : attr_values) {
|
||||
m.insert({aval.first, aval.second.proto});
|
||||
}
|
||||
return errors::InvalidArgument("Attr ", attr, " is not defined.");
|
||||
return InstantiateFunction(fdef, m, std::move(get_function), result);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status FunctionLibraryDefinition::GetAttr(const Node& node, const string& attr,
|
||||
T* value) const {
|
||||
return GetAttr(node.def(), attr, value);
|
||||
string Canonicalize(const string& funcname, InstantiateAttrValueSlice attrs) {
|
||||
InstantiateAttrValueMap m;
|
||||
for (const auto& aval : attrs) {
|
||||
m.insert({aval.first, aval.second.proto});
|
||||
}
|
||||
return Canonicalize(funcname, m);
|
||||
}
|
||||
|
||||
#define GET_ATTR(T) \
|
||||
template Status FunctionLibraryDefinition::GetAttr(const Node&, \
|
||||
const string&, T*) const; \
|
||||
template Status FunctionLibraryDefinition::GetAttr(const NodeDef&, \
|
||||
const string&, T*) const;
|
||||
GET_ATTR(string)
|
||||
GET_ATTR(bool)
|
||||
#undef GET_ATTR
|
||||
Status FunctionLibraryRuntime::Instantiate(const string& function_name,
|
||||
InstantiateAttrValueSlice attrs,
|
||||
Handle* handle) {
|
||||
InstantiateAttrValueMap m;
|
||||
for (const auto& aval : attrs) {
|
||||
m.insert({aval.first, aval.second.proto});
|
||||
}
|
||||
return Instantiate(function_name, m, handle);
|
||||
}
|
||||
|
||||
void FunctionDefHelper::AttrValueWrapper::InitFromString(StringPiece val) {
|
||||
if (val.size() >= 2 && val[0] == '$') {
|
||||
|
@ -36,7 +36,6 @@ class CancellationManager;
|
||||
class OpKernel;
|
||||
class ResourceMgr;
|
||||
class ScopedStepContainer;
|
||||
class Node;
|
||||
|
||||
// FunctionDefHelper::Create is a convenient helper to construct a
|
||||
// FunctionDef proto.
|
||||
@ -191,6 +190,11 @@ inline FunctionDefHelper::AttrValueWrapper::AttrValueWrapper(StringPiece val) {
|
||||
// InstantiateFunction calls "get_function" to find signatures of other
|
||||
// functions and primitive ops.
|
||||
|
||||
// Placeholders in "fdef" is substituted based on "attr_values" here.
|
||||
typedef ::tensorflow::protobuf::Map<string, AttrValue> InstantiateAttrValueMap;
|
||||
typedef gtl::ArraySlice<std::pair<string, FunctionDefHelper::AttrValueWrapper>>
|
||||
InstantiateAttrValueSlice;
|
||||
|
||||
// GetFunctionSignature(func name, opdef) returns OK if the func name is found
|
||||
// and opdef is filled with a pointer to the corresponding signature
|
||||
// (a OpDef proto). Otherwise, returns an error.
|
||||
@ -202,7 +206,12 @@ struct InstantiationResult {
|
||||
DataTypeVector ret_types;
|
||||
GraphDef gdef;
|
||||
};
|
||||
Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values,
|
||||
Status InstantiateFunction(const FunctionDef& fdef,
|
||||
const InstantiateAttrValueMap& attr_values,
|
||||
GetFunctionSignature get_function,
|
||||
InstantiationResult* result);
|
||||
Status InstantiateFunction(const FunctionDef& fdef,
|
||||
InstantiateAttrValueSlice attr_values,
|
||||
GetFunctionSignature get_function,
|
||||
InstantiationResult* result);
|
||||
|
||||
@ -232,7 +241,9 @@ bool FunctionDefsEqual(const FunctionDef& f1, const FunctionDef& f2);
|
||||
// space. But it may be change as the implementation
|
||||
// evolves. Therefore, it should not be persisted or compared across
|
||||
// address spaces.
|
||||
string Canonicalize(const string& funcname, AttrSlice attrs);
|
||||
string Canonicalize(const string& funcname,
|
||||
const InstantiateAttrValueMap& attrs);
|
||||
string Canonicalize(const string& funcname, InstantiateAttrValueSlice attrs);
|
||||
|
||||
// Represents a function call frame. I.e., the data structure used to
|
||||
// pass arguments to a function and retrieve its results.
|
||||
@ -319,16 +330,9 @@ class FunctionLibraryDefinition : public OpRegistryInterface {
|
||||
// Given a node def 'ndef', inspects attributes of the callee
|
||||
// function to derive the attribute 'value' for 'attr'. Returns OK
|
||||
// iff the attribute is given by the function's definition.
|
||||
// TODO(irving): Remove; keep only the const Node& version.
|
||||
template <typename T>
|
||||
Status GetAttr(const NodeDef& ndef, const string& attr, T* value) const;
|
||||
|
||||
// Given a node, inspects attributes of the callee function to derive the
|
||||
// attribute 'value' for 'attr'. Returns OK iff the attribute is given by the
|
||||
// function's definition.
|
||||
template <typename T>
|
||||
Status GetAttr(const Node& node, const string& attr, T* value) const;
|
||||
|
||||
// Returns a proto representation of the state of this function library.
|
||||
FunctionDefLibrary ToProto() const;
|
||||
|
||||
@ -371,8 +375,11 @@ class FunctionLibraryRuntime {
|
||||
// Returns OK and fills in "handle" if the instantiation succeeds.
|
||||
// Otherwise returns an error and "handle" is undefined.
|
||||
typedef uint64 Handle;
|
||||
virtual Status Instantiate(const string& function_name, AttrSlice attrs,
|
||||
virtual Status Instantiate(const string& function_name,
|
||||
const InstantiateAttrValueMap& attrs,
|
||||
Handle* handle) = 0;
|
||||
Status Instantiate(const string& function_name,
|
||||
InstantiateAttrValueSlice attrs, Handle* handle);
|
||||
|
||||
// Returns the function body for the instantiated function given its
|
||||
// handle 'h'. Returns nullptr if "h" is not found.
|
||||
@ -499,15 +506,17 @@ bool RegisterOp(const string& op, Creator func);
|
||||
Status GetOpGradientCreator(const string& op, Creator* creator);
|
||||
};
|
||||
|
||||
// Declare explicit instantiations of GetAttr
|
||||
#define GET_ATTR(T) \
|
||||
extern template Status FunctionLibraryDefinition::GetAttr( \
|
||||
const Node&, const string&, T*) const; \
|
||||
extern template Status FunctionLibraryDefinition::GetAttr( \
|
||||
const NodeDef&, const string&, T*) const;
|
||||
GET_ATTR(string)
|
||||
GET_ATTR(bool)
|
||||
#undef GET_ATTR
|
||||
// Implementation details.
|
||||
|
||||
template <typename T>
|
||||
Status FunctionLibraryDefinition::GetAttr(const NodeDef& ndef,
|
||||
const string& attr, T* value) const {
|
||||
const FunctionDef* fdef = GetAttrImpl(ndef);
|
||||
if (fdef && GetNodeAttr(AttrSlice(&fdef->attr()), attr, value).ok()) {
|
||||
return Status::OK();
|
||||
}
|
||||
return errors::InvalidArgument("Attr ", attr, " is not defined.");
|
||||
}
|
||||
|
||||
} // end namespace tensorflow
|
||||
|
||||
|
@ -29,24 +29,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
// A helper class to make AttrSlice from initializer lists
|
||||
class Attrs {
|
||||
public:
|
||||
Attrs(const std::initializer_list< // NOLINT(runtime/explicit)
|
||||
std::pair<string, FunctionDefHelper::AttrValueWrapper>>
|
||||
attrs) {
|
||||
for (const auto& aval : attrs) {
|
||||
map_.insert({aval.first, aval.second.proto});
|
||||
}
|
||||
}
|
||||
|
||||
operator AttrSlice() { return AttrSlice(&map_); } // NOLINT(runtime/explicit)
|
||||
|
||||
private:
|
||||
AttrValueMap map_;
|
||||
};
|
||||
|
||||
typedef FunctionDefHelper FDH;
|
||||
|
||||
@ -64,6 +46,8 @@ y: A scalar in type T.
|
||||
|
||||
)doc");
|
||||
|
||||
static InstantiateAttrValueMap kNoAttrs;
|
||||
|
||||
TEST(TFunc, SquarePlusOne) {
|
||||
auto fdef = FDH::Create(
|
||||
// Name
|
||||
@ -97,8 +81,7 @@ SquarePlusOne[T:{float, double, int32, int64}](x:T) -> (y:T) {
|
||||
|
||||
// Instantiate one with T=float
|
||||
InstantiationResult result;
|
||||
TF_ASSERT_OK(
|
||||
InstantiateFunction(fdef, Attrs({{"T", DT_FLOAT}}), GetOpSig, &result));
|
||||
TF_ASSERT_OK(InstantiateFunction(fdef, {{"T", DT_FLOAT}}, GetOpSig, &result));
|
||||
const char* e2 = R"P(
|
||||
(x:float) -> (y:float) {
|
||||
a = Square[T=float](x)
|
||||
@ -143,8 +126,7 @@ ControlDep(x:int32) -> (y:int32) {
|
||||
|
||||
// Instantiate one with T=float
|
||||
InstantiationResult result;
|
||||
TF_ASSERT_OK(
|
||||
InstantiateFunction(fdef, Attrs({{"T", DT_FLOAT}}), GetOpSig, &result));
|
||||
TF_ASSERT_OK(InstantiateFunction(fdef, {{"T", DT_FLOAT}}, GetOpSig, &result));
|
||||
const char* e2 = R"P(
|
||||
(x:int32) -> (y:int32) {
|
||||
a = Identity[T=int32](x)
|
||||
@ -189,7 +171,8 @@ BackCompat() -> (y:float) {
|
||||
EXPECT_EQ(DebugString(fdef), e);
|
||||
|
||||
InstantiationResult result;
|
||||
TF_ASSERT_OK(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result));
|
||||
TF_ASSERT_OK(
|
||||
InstantiateFunction(fdef, InstantiateAttrValueMap{}, GetOpSig, &result));
|
||||
// Should get T=float from Op's default.
|
||||
const char* e2 = R"P(
|
||||
() -> (a:float) {
|
||||
@ -226,7 +209,7 @@ NTimesT(x:float, y:float) -> (z:float) {
|
||||
EXPECT_EQ(DebugString(fdef), e);
|
||||
|
||||
InstantiationResult result;
|
||||
TF_ASSERT_OK(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result));
|
||||
TF_ASSERT_OK(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result));
|
||||
const char* e2 = R"P(
|
||||
(x:float, y:float) -> (a:float) {
|
||||
a = AddN[N=2, T=float](x, y)
|
||||
@ -289,8 +272,8 @@ AddSquared[N:int, T:{float, double, int32, int64}](x:N*T) -> (y:T) {
|
||||
|
||||
// Instantiate one with T=float
|
||||
InstantiationResult result;
|
||||
TF_ASSERT_OK(InstantiateFunction(fdef, Attrs({{"N", 3}, {"T", DT_FLOAT}}),
|
||||
GetOpSig, &result));
|
||||
TF_ASSERT_OK(InstantiateFunction(fdef, {{"N", 3}, {"T", DT_FLOAT}}, GetOpSig,
|
||||
&result));
|
||||
const char* e2 = R"P(
|
||||
(x_0:float, x_1:float, x_2:float) -> (y:float) {
|
||||
a = Map[N=3, T=float, U=float, func=Square[T=float]](x_0, x_1, x_2)
|
||||
@ -332,7 +315,7 @@ ControlDeps(x:float) -> () {
|
||||
EXPECT_EQ(DebugString(fdef), e);
|
||||
|
||||
InstantiationResult result;
|
||||
TF_ASSERT_OK(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result));
|
||||
TF_ASSERT_OK(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result));
|
||||
const char* e2 = R"P(
|
||||
(x:float) -> () {
|
||||
a = One[T=float]() @ x
|
||||
@ -412,7 +395,7 @@ Test(i:float) -> (o:float) {
|
||||
EXPECT_EQ(DebugString(fdef), e);
|
||||
|
||||
InstantiationResult result;
|
||||
TF_ASSERT_OK(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result));
|
||||
TF_ASSERT_OK(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result));
|
||||
const char* e2 = R"P(
|
||||
(i:float) -> (o:float) {
|
||||
zero = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 0>]()
|
||||
@ -484,7 +467,7 @@ MySelect(x:float) -> (z:float) {
|
||||
EXPECT_EQ(DebugString(fdef), e);
|
||||
|
||||
InstantiationResult result;
|
||||
TF_ASSERT_OK(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result));
|
||||
TF_ASSERT_OK(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result));
|
||||
const char* e2 = R"P(
|
||||
(x:float) -> (z:float) {
|
||||
y = Cond[Tin={float}, cond=MyCond, else_branch=MyElse, out_types={float}, then_branch=MyThen](x)
|
||||
@ -505,9 +488,8 @@ TEST(InstantiateErrors, Not_Sufficient_Attrs) {
|
||||
auto fdef =
|
||||
FDH::Define("nop", {}, {}, {"T:{float, double, int32, int64}"}, {});
|
||||
InstantiationResult result;
|
||||
HasError(
|
||||
InstantiateFunction(fdef, Attrs({{"U", DT_FLOAT}}), GetOpSig, &result),
|
||||
"Attr T is not found from ");
|
||||
HasError(InstantiateFunction(fdef, {{"U", DT_FLOAT}}, GetOpSig, &result),
|
||||
"Attr T is not found from ");
|
||||
}
|
||||
|
||||
#if 0 // TODO(josh11b): Enable this test once having an extra attr is an error.
|
||||
@ -515,7 +497,7 @@ TEST(InstantiateErrors, Too_Many_Attrs) {
|
||||
auto fdef =
|
||||
FDH::Define("nop", {}, {}, {"T:{float, double, int32, int64}"}, {});
|
||||
InstantiationResult result;
|
||||
HasError(InstantiateFunction(fdef, Attrs({{"T", DT_INT32}, {"U", DT_FLOAT}}),
|
||||
HasError(InstantiateFunction(fdef, {{"T", DT_INT32}, {"U", DT_FLOAT}},
|
||||
GetOpSig, &result),
|
||||
"Attr U is not found in ");
|
||||
}
|
||||
@ -526,7 +508,7 @@ TEST(InstantiateErrors, AttrValue_Value_Placeholder) {
|
||||
FDH::Define("nop", {}, {}, {"T:{float, double, int32, int64}"}, {});
|
||||
InstantiationResult result;
|
||||
HasError(
|
||||
InstantiateFunction(fdef, Attrs({{"T", "$bad"}}), GetOpSig, &result),
|
||||
InstantiateFunction(fdef, {{"T", "$bad"}}, GetOpSig, &result),
|
||||
"AttrValue had value with unexpected type 'placeholder'\n\tfor attr 'T'");
|
||||
}
|
||||
|
||||
@ -536,15 +518,14 @@ TEST(InstantiateErrors, Unbounded_Attr) {
|
||||
{{"a"}, "One", {}, {{"T", "$unknown"}}, {"x"}},
|
||||
});
|
||||
InstantiationResult result;
|
||||
HasError(
|
||||
InstantiateFunction(fdef, Attrs({{"T", DT_FLOAT}}), GetOpSig, &result),
|
||||
"Failed to bind all placeholders");
|
||||
HasError(InstantiateFunction(fdef, {{"T", DT_FLOAT}}, GetOpSig, &result),
|
||||
"Failed to bind all placeholders");
|
||||
}
|
||||
|
||||
TEST(InstantiateErrors, DupArgs) {
|
||||
auto fdef = FDH::Define("test", {"x:float", "x:float"}, {}, {}, {});
|
||||
InstantiationResult result;
|
||||
HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
|
||||
HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
|
||||
"Duplicated arg name");
|
||||
}
|
||||
|
||||
@ -555,7 +536,7 @@ TEST(InstantiateErrors, Dup_Node_Names) {
|
||||
{{"y"}, "One", {}, {{"T", DT_FLOAT}}},
|
||||
});
|
||||
InstantiationResult result;
|
||||
HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
|
||||
HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
|
||||
"Duplicated ret name");
|
||||
}
|
||||
|
||||
@ -566,7 +547,7 @@ TEST(InstantiateErrors, Node_Arg_Notfound) {
|
||||
},
|
||||
{});
|
||||
InstantiationResult result;
|
||||
HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
|
||||
HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
|
||||
"input z is not found");
|
||||
}
|
||||
|
||||
@ -576,7 +557,7 @@ TEST(InstantiateErrors, Node_Arg_TypeMismatch) {
|
||||
{{"y"}, "Add", {"x", "x"}, {{"T", DT_INT32}}},
|
||||
});
|
||||
InstantiationResult result;
|
||||
HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
|
||||
HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
|
||||
"input x[0] expected type int32 != float, the type of x[0]");
|
||||
}
|
||||
|
||||
@ -587,7 +568,7 @@ TEST(InstantiateErrors, Node_Arg_ControlMissing) {
|
||||
{{"y"}, "Add", {"x", "x"}, {{"T", DT_FLOAT}}, {"z"}},
|
||||
});
|
||||
InstantiationResult result;
|
||||
HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
|
||||
HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
|
||||
"input[2] == '^z', is not found.");
|
||||
}
|
||||
|
||||
@ -598,7 +579,7 @@ TEST(InstantiateErrors, FuncRet_Missing) {
|
||||
},
|
||||
{});
|
||||
InstantiationResult result;
|
||||
HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
|
||||
HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
|
||||
"Return y missing");
|
||||
}
|
||||
|
||||
@ -609,7 +590,7 @@ TEST(InstantiateErrors, FuncRet_NotFound) {
|
||||
},
|
||||
{{"y", "z"}});
|
||||
InstantiationResult result;
|
||||
HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
|
||||
HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
|
||||
"Return y -> z is not found");
|
||||
}
|
||||
|
||||
@ -620,7 +601,7 @@ TEST(InstantiateErrors, FuncRet_NameMismatch) {
|
||||
},
|
||||
{{"z", "x:y:0"}});
|
||||
InstantiationResult result;
|
||||
HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
|
||||
HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
|
||||
"Return y missing");
|
||||
}
|
||||
|
||||
@ -632,7 +613,7 @@ TEST(InstantiateErrors, FuncRet_NameMismatch) {
|
||||
// },
|
||||
// {{"y", "x:y:0"}, {"z", "x:y:0"}});
|
||||
// InstantiationResult result;
|
||||
// HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
|
||||
// HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
|
||||
// "ret is not found");
|
||||
// }
|
||||
|
||||
@ -642,7 +623,7 @@ TEST(InstantiateErrors, FuncRet_TypeMismatch) {
|
||||
{{"y"}, "One", {}, {{"T", DT_DOUBLE}}},
|
||||
});
|
||||
InstantiationResult result;
|
||||
HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
|
||||
HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
|
||||
"Invalid ret types y : float vs. double\n\tIn function output y");
|
||||
}
|
||||
|
||||
@ -668,7 +649,7 @@ TEST(InstantiateErrors, TypeList_Missing_Retval_Attr) {
|
||||
},
|
||||
{{"y", "y:output"}});
|
||||
InstantiationResult result;
|
||||
HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
|
||||
HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
|
||||
"type attr not found: out_types");
|
||||
}
|
||||
|
||||
@ -695,7 +676,7 @@ TEST(InstantiateErrors, TypeList_Num_Retval_Mismatch) {
|
||||
},
|
||||
{{"y", "y:output"}});
|
||||
InstantiationResult result;
|
||||
HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
|
||||
HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
|
||||
"Invalid ret types");
|
||||
}
|
||||
|
||||
@ -722,7 +703,7 @@ TEST(InstantiateErrors, TypeList_Missing_Arg) {
|
||||
},
|
||||
{{"y", "y:output"}});
|
||||
InstantiationResult result;
|
||||
HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
|
||||
HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
|
||||
"input unknown is not found");
|
||||
}
|
||||
|
||||
@ -743,7 +724,7 @@ TEST(InstantiateErrors, TooManyInputs) {
|
||||
{{"z", "a:sum:0"}});
|
||||
|
||||
InstantiationResult result;
|
||||
HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
|
||||
HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
|
||||
"Expected input[2] == 'x' to be a control input.");
|
||||
}
|
||||
|
||||
@ -764,7 +745,7 @@ TEST(InstantiateErrors, TooFewInputs) {
|
||||
{{"z", "a:sum:0"}});
|
||||
|
||||
InstantiationResult result;
|
||||
HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
|
||||
HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
|
||||
"Attempt to access beyond input size: 2 >= 2");
|
||||
}
|
||||
|
||||
@ -792,7 +773,7 @@ TEST(InstantiateErrors, TooManyInputsFromArray1) {
|
||||
{{"z", "a:sum:0"}});
|
||||
|
||||
InstantiationResult result;
|
||||
HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
|
||||
HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
|
||||
"Expected input[1] == 'y' to be a control input.");
|
||||
}
|
||||
|
||||
@ -820,7 +801,7 @@ TEST(InstantiateErrors, TooManyInputsFromArray2) {
|
||||
{{"z", "a:sum:0"}});
|
||||
|
||||
InstantiationResult result;
|
||||
HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
|
||||
HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
|
||||
"Input a:output too long for inputs");
|
||||
}
|
||||
|
||||
@ -841,7 +822,7 @@ TEST(InstantiateErrors, TypeMismatch) {
|
||||
{{"z", "a:sum:0"}});
|
||||
|
||||
InstantiationResult result;
|
||||
HasError(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result),
|
||||
HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
|
||||
"input inputs[1] expected type float != int32, the type of y[0]");
|
||||
}
|
||||
|
||||
@ -893,17 +874,17 @@ TEST(FunctionCallFrame, Float_Float_Float) {
|
||||
}
|
||||
|
||||
TEST(Canonicalize, Basic) {
|
||||
EXPECT_EQ(Canonicalize("MatMul", Attrs({{"T", DT_FLOAT},
|
||||
{"transpose_a", false},
|
||||
{"transpose_b", false}})),
|
||||
EXPECT_EQ(Canonicalize("MatMul", {{"T", DT_FLOAT},
|
||||
{"transpose_a", false},
|
||||
{"transpose_b", false}}),
|
||||
"MatMul[T=float,transpose_a=false,transpose_b=false]");
|
||||
EXPECT_EQ(Canonicalize("MatMul", Attrs({{"T", DT_FLOAT},
|
||||
{"transpose_b", false},
|
||||
{"transpose_a", false}})),
|
||||
EXPECT_EQ(Canonicalize("MatMul", {{"T", DT_FLOAT},
|
||||
{"transpose_b", false},
|
||||
{"transpose_a", false}}),
|
||||
"MatMul[T=float,transpose_a=false,transpose_b=false]");
|
||||
EXPECT_EQ(Canonicalize("MatMul", Attrs({{"T", DT_DOUBLE},
|
||||
{"transpose_b", true},
|
||||
{"transpose_a", false}})),
|
||||
EXPECT_EQ(Canonicalize("MatMul", {{"T", DT_DOUBLE},
|
||||
{"transpose_b", true},
|
||||
{"transpose_a", false}}),
|
||||
"MatMul[T=double,transpose_a=false,transpose_b=true]");
|
||||
}
|
||||
|
||||
@ -1167,5 +1148,4 @@ TEST(FunctionDefsEqualTest, TestFunctionDefsEqual) {
|
||||
EXPECT_FALSE(FunctionDefsEqual(fdef1, fdef2));
|
||||
}
|
||||
|
||||
} // end namespace
|
||||
} // end namespace tensorflow
|
||||
|
@ -25,7 +25,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/op_def.pb_text.h"
|
||||
#include "tensorflow/core/framework/op_def_util.h"
|
||||
#include "tensorflow/core/framework/tensor.pb_text.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||
#include "tensorflow/core/lib/strings/scanner.h"
|
||||
@ -37,23 +36,18 @@ namespace tensorflow {
|
||||
const char* const kColocationAttrName = "_class";
|
||||
const char* const kColocationGroupPrefix = "loc:@";
|
||||
|
||||
AttrSlice::AttrSlice() : ndef_(nullptr) {
|
||||
static const AttrValueMap* const kEmptyAttrValueMap = new AttrValueMap;
|
||||
attrs_ = kEmptyAttrValueMap;
|
||||
}
|
||||
|
||||
AttrSlice::AttrSlice(const NodeDef& node_def)
|
||||
: ndef_(&node_def), attrs_(&ndef_->attr()) {}
|
||||
|
||||
AttrSlice::AttrSlice(const AttrValueMap* a) : ndef_(nullptr), attrs_(a) {}
|
||||
|
||||
static string SummarizeAttrsHelper(AttrSlice attrs, StringPiece device) {
|
||||
string ret;
|
||||
string SummarizeNodeDef(const NodeDef& node_def) {
|
||||
string ret = strings::StrCat(node_def.name(), " = ", node_def.op(), "[");
|
||||
|
||||
// We sort the attrs so the output is deterministic.
|
||||
std::vector<string> attr_names;
|
||||
attr_names.reserve(attrs.size());
|
||||
for (const auto& attr : attrs) {
|
||||
attr_names.reserve(node_def.attr().size());
|
||||
for (const auto& attr : node_def.attr()) {
|
||||
attr_names.push_back(attr.first);
|
||||
}
|
||||
std::sort(attr_names.begin(), attr_names.end());
|
||||
@ -61,34 +55,20 @@ static string SummarizeAttrsHelper(AttrSlice attrs, StringPiece device) {
|
||||
for (const string& attr_name : attr_names) {
|
||||
if (!first) strings::StrAppend(&ret, ", ");
|
||||
first = false;
|
||||
strings::StrAppend(&ret, attr_name, "=",
|
||||
SummarizeAttrValue(*attrs.Find(attr_name)));
|
||||
auto iter = node_def.attr().find(attr_name);
|
||||
strings::StrAppend(&ret, attr_name, "=", SummarizeAttrValue(iter->second));
|
||||
}
|
||||
|
||||
// Consider the device to be a final attr with name "_device".
|
||||
if (!device.empty()) {
|
||||
if (!node_def.device().empty()) {
|
||||
if (!first) strings::StrAppend(&ret, ", ");
|
||||
first = false;
|
||||
strings::StrAppend(&ret, "_device=\"", device, "\"");
|
||||
strings::StrAppend(&ret, "_device=\"", node_def.device(), "\"");
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
string AttrSlice::SummarizeNode() const {
|
||||
return ndef_ ? SummarizeNodeDef(*ndef_)
|
||||
: strings::StrCat(
|
||||
"[", SummarizeAttrsHelper(*this, StringPiece()), "]");
|
||||
}
|
||||
|
||||
string SummarizeNode(const Node& node) { return SummarizeNodeDef(node.def()); }
|
||||
|
||||
string SummarizeNodeDef(const NodeDef& node_def) {
|
||||
string ret = strings::StrCat(node_def.name(), " = ", node_def.op(), "[");
|
||||
strings::StrAppend(&ret, SummarizeAttrsHelper(node_def, node_def.device()));
|
||||
strings::StrAppend(&ret, "](");
|
||||
|
||||
// Output inputs, including control inputs, verbatim.
|
||||
bool first = true;
|
||||
first = true;
|
||||
for (const string& input : node_def.input()) {
|
||||
if (!first) strings::StrAppend(&ret, ", ");
|
||||
first = false;
|
||||
@ -129,28 +109,12 @@ Status AttrSlice::Find(StringPiece attr_name,
|
||||
// Skip AttachDef for internal attrs since it is a little bit
|
||||
// expensive and it is common for them to correctly not be included
|
||||
// in a NodeDef.
|
||||
if (!attr_name.starts_with("_") && ndef_ != nullptr) {
|
||||
if (!StringPiece(attr_name).starts_with("_") && ndef_) {
|
||||
s = AttachDef(s, *ndef_);
|
||||
}
|
||||
return s;
|
||||
}
|
||||
|
||||
bool AttrSlice::EqualAttrs(AttrSlice other, Scratch* scratch) const {
|
||||
if (size() != other.size()) return false;
|
||||
|
||||
for (const auto& attr : *other.attrs_) {
|
||||
auto iter = attrs_->find(attr.first);
|
||||
if (iter == attrs_->end()) return false;
|
||||
// TODO(irving): Comparing AttrValues by proto is slightly buggy, since
|
||||
// TensorProto is a nonunique representation of Tensor. This bug will go
|
||||
// away once AttrSlice switches over to NodeInfo.
|
||||
iter->second.SerializeToString(&scratch->a);
|
||||
attr.second.SerializeToString(&scratch->b);
|
||||
if (scratch->a != scratch->b) return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// The ... is to allow the caller to inject some value validation code. Use
|
||||
// just ; if no additional validation code is needed.
|
||||
#define DEFINE_GET_ATTR(TYPE, FIELD, ATTR_TYPE, APPEND_OP, CAST, ...) \
|
||||
@ -377,14 +341,14 @@ Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def) {
|
||||
if (StringPiece(input).starts_with("^")) {
|
||||
seen_control = true;
|
||||
if (input.find(':') != string::npos) {
|
||||
return errors::InvalidArgument(
|
||||
"Control input '", input,
|
||||
"' must not have ':' in NodeDef: ", SummarizeNodeDef(node_def));
|
||||
return errors::InvalidArgument("Control input '", input,
|
||||
"' must not have ':' in NodeDef: ",
|
||||
SummarizeNodeDef(node_def));
|
||||
}
|
||||
} else if (seen_control) {
|
||||
return errors::InvalidArgument(
|
||||
"Non-control input '", input,
|
||||
"' after control input in NodeDef: ", SummarizeNodeDef(node_def));
|
||||
return errors::InvalidArgument("Non-control input '", input,
|
||||
"' after control input in NodeDef: ",
|
||||
SummarizeNodeDef(node_def));
|
||||
} else {
|
||||
++num_inputs;
|
||||
}
|
||||
@ -394,8 +358,8 @@ Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def) {
|
||||
for (const auto& attr : op_def.attr()) {
|
||||
if (!gtl::InsertIfNotPresent(&op_attrs, attr.name(), &attr)) {
|
||||
return errors::InvalidArgument("OpDef has duplicate attr name '",
|
||||
attr.name(),
|
||||
"': ", SummarizeOpDef(op_def));
|
||||
attr.name(), "': ",
|
||||
SummarizeOpDef(op_def));
|
||||
}
|
||||
}
|
||||
for (const auto& attr : node_def.attr()) {
|
||||
@ -419,9 +383,8 @@ Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def) {
|
||||
"with your GraphDef-generating binary.).");
|
||||
}
|
||||
TF_RETURN_WITH_CONTEXT_IF_ERROR(
|
||||
ValidateAttrValue(attr.second, *iter->second),
|
||||
"; NodeDef: ", SummarizeNodeDef(node_def), "; ",
|
||||
SummarizeOpDef(op_def));
|
||||
ValidateAttrValue(attr.second, *iter->second), "; NodeDef: ",
|
||||
SummarizeNodeDef(node_def), "; ", SummarizeOpDef(op_def));
|
||||
// Keep track of which attr names have (not) been found in the NodeDef.
|
||||
op_attrs.erase(iter);
|
||||
}
|
||||
@ -468,9 +431,9 @@ Status ComputeArgRange(const NodeDef& node_def, const OpDef::ArgDef& arg_def,
|
||||
} else if (!arg_def.type_attr().empty() || arg_def.type() != DT_INVALID) {
|
||||
*num = 1;
|
||||
} else {
|
||||
return errors::InvalidArgument(
|
||||
"Argument '", arg_def.name(),
|
||||
"' incorrectly specified in op definition: ", SummarizeOpDef(op_def));
|
||||
return errors::InvalidArgument("Argument '", arg_def.name(),
|
||||
"' incorrectly specified in op definition: ",
|
||||
SummarizeOpDef(op_def));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
@ -502,11 +465,6 @@ Status NameRangesForNode(const NodeDef& node_def, const OpDef& op_def,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status NameRangesForNode(const Node& node, const OpDef& op_def,
|
||||
NameRangeMap* inputs, NameRangeMap* outputs) {
|
||||
return NameRangesForNode(node.def(), op_def, inputs, outputs);
|
||||
}
|
||||
|
||||
void AddDefaultsToNodeDef(const OpDef& op_def, NodeDef* node_def) {
|
||||
for (const auto& attr_def : op_def.attr()) {
|
||||
AttrSlice attrs(*node_def);
|
||||
@ -607,8 +565,4 @@ Status AttachDef(const Status& status, const NodeDef& node_def) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
Status AttachDef(const Status& status, const Node& node) {
|
||||
return AttachDef(status, node.def());
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -29,8 +29,6 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class Node;
|
||||
|
||||
// Name of the attribute used to encode node colocation constraints.
|
||||
//
|
||||
// Nodes can be co-located on the same device. Desire for explicit co-location
|
||||
@ -41,9 +39,8 @@ extern const char* const kColocationAttrName;
|
||||
// String prefix applied to the operation name for colocation constraints.
|
||||
extern const char* const kColocationGroupPrefix;
|
||||
|
||||
// Produce a human-readable version of a Node or NodeDef that is more concise
|
||||
// Produce a human-readable version of a NodeDef that is more concise
|
||||
// than a text-format proto.
|
||||
string SummarizeNode(const Node& node);
|
||||
string SummarizeNodeDef(const NodeDef& node_def);
|
||||
|
||||
typedef protobuf::Map<string, AttrValue> AttrValueMap;
|
||||
@ -81,11 +78,8 @@ class AttrSlice {
|
||||
public:
|
||||
AttrSlice(const NodeDef& node_def); // NOLINT(runtime/explicit)
|
||||
|
||||
AttrSlice(); // Empty
|
||||
explicit AttrSlice(const AttrValueMap* a);
|
||||
|
||||
int size() const { return attrs_->size(); }
|
||||
|
||||
// Returns the attr with attr_name if found. Otherwise, returns
|
||||
// nullptr.
|
||||
const AttrValue* Find(StringPiece attr_name) const;
|
||||
@ -94,33 +88,6 @@ class AttrSlice {
|
||||
// NotFound status.
|
||||
Status Find(StringPiece attr_name, const AttrValue** attr_value) const;
|
||||
|
||||
// Helper class to avoid allocations in EqualAttrs.
|
||||
// TODO(irving): Will go away once NodeInfo is used.
|
||||
struct Scratch {
|
||||
string a;
|
||||
string b;
|
||||
};
|
||||
|
||||
// Check if all attrs and attr values match. Does not take defaults into
|
||||
// account.
|
||||
//
|
||||
// TODO(irving): There is a bug in this routine inherited from its
|
||||
// OptimizerCSE::EqualAttrs precedecessor. The same tensor attr can be
|
||||
// represented in more than one way as an AttrValue, since TensorProto is
|
||||
// not 1-1. This bug will go away once I replace everything with NodeInfo,
|
||||
// which stores a Tensor object directly. The Scratch object will also go
|
||||
// away.
|
||||
bool EqualAttrs(AttrSlice other, Scratch* scratch) const;
|
||||
|
||||
// If this AttrSlice has an attached NodeDef, summarize it. This is for
|
||||
// error messages only: we intentionally do not provide direct access to the
|
||||
// NodeDef, since it is not always there.
|
||||
string SummarizeNode() const;
|
||||
|
||||
// Iteration over all attrs
|
||||
AttrValueMap::const_iterator begin() const { return attrs_->begin(); }
|
||||
AttrValueMap::const_iterator end() const { return attrs_->end(); }
|
||||
|
||||
private:
|
||||
const NodeDef* ndef_;
|
||||
const AttrValueMap* attrs_;
|
||||
@ -216,12 +183,9 @@ Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def);
|
||||
// corresponding input/output index range. For example,
|
||||
// input "foo" corresponds to input indices
|
||||
// [ (*inputs)["foo"].first, (*inputs)["foo"].second ).
|
||||
// TODO(irving): Remove the NodeDef version; keep only the Node version.
|
||||
typedef std::unordered_map<string, std::pair<int, int>> NameRangeMap;
|
||||
Status NameRangesForNode(const NodeDef& node_def, const OpDef& op_def,
|
||||
NameRangeMap* inputs, NameRangeMap* outputs);
|
||||
Status NameRangesForNode(const Node& node, const OpDef& op_def,
|
||||
NameRangeMap* inputs, NameRangeMap* outputs);
|
||||
|
||||
// Adds default values to *node_def for unspecified attrs from op_def.
|
||||
void AddDefaultsToNodeDef(const OpDef& op_def, NodeDef* node_def);
|
||||
@ -242,7 +206,6 @@ Status ValidateExternalNodeDefSyntax(const NodeDef& node_def);
|
||||
// Returns "status" with kernel's NodeDef attached as additional text
|
||||
// in the error message.
|
||||
Status AttachDef(const Status& status, const NodeDef& node_def);
|
||||
Status AttachDef(const Status& status, const Node& node);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -27,7 +27,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/framework/op_def_util.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/notification.h"
|
||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||
@ -843,10 +842,13 @@ bool InTypeList(DataType dt, const AttrValue& type_list) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Returns whether the attrs satisfy the constraints in the kernel_def. Returns
|
||||
// an error if attrs in kernel_def are not found, or have a mismatching type.
|
||||
Status AttrsMatch(AttrSlice attrs, const KernelDef& kernel_def, bool* match) {
|
||||
// Returns whether the attrs in the NodeDef satisfy the constraints in
|
||||
// the kernel_def. Returns an error if attrs in kernel_def are not
|
||||
// found, or have a mismatching type.
|
||||
Status AttrsMatch(const NodeDef& node_def, const KernelDef& kernel_def,
|
||||
bool* match) {
|
||||
*match = false;
|
||||
AttrSlice attrs(node_def);
|
||||
for (const auto& constraint : kernel_def.constraint()) {
|
||||
if (constraint.allowed_values().list().type_size() == 0) {
|
||||
return errors::Unimplemented(
|
||||
@ -870,7 +872,7 @@ Status AttrsMatch(AttrSlice attrs, const KernelDef& kernel_def, bool* match) {
|
||||
"' that has value '", SummarizeAttrValue(*found),
|
||||
"' that does not have type 'type' or 'list(type)' in NodeDef "
|
||||
"'",
|
||||
attrs.SummarizeNode(), "'");
|
||||
SummarizeNodeDef(node_def), "'");
|
||||
}
|
||||
|
||||
for (int t : found->list().type()) {
|
||||
@ -883,7 +885,7 @@ Status AttrsMatch(AttrSlice attrs, const KernelDef& kernel_def, bool* match) {
|
||||
} else {
|
||||
return errors::InvalidArgument(
|
||||
"OpKernel '", kernel_def.op(), "' has constraint on attr '",
|
||||
constraint.name(), "' not in NodeDef '", attrs.SummarizeNode(),
|
||||
constraint.name(), "' not in NodeDef '", SummarizeNodeDef(node_def),
|
||||
"', KernelDef: '", ProtoShortDebugString(kernel_def), "'");
|
||||
}
|
||||
}
|
||||
@ -893,7 +895,6 @@ Status AttrsMatch(AttrSlice attrs, const KernelDef& kernel_def, bool* match) {
|
||||
|
||||
static const StringPiece kKernelAttr("_kernel");
|
||||
|
||||
// TODO(irving): Replace with const Node& version below.
|
||||
Status FindKernelRegistration(const DeviceType& device_type,
|
||||
const NodeDef& node_def,
|
||||
const KernelRegistration** reg,
|
||||
@ -926,16 +927,8 @@ Status FindKernelRegistration(const DeviceType& device_type,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status FindKernelRegistration(const DeviceType& device_type, const Node& node,
|
||||
const KernelRegistration** reg,
|
||||
bool* was_attr_mismatch) {
|
||||
return FindKernelRegistration(device_type, node.def(), reg,
|
||||
was_attr_mismatch);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// TODO(irving): Change const NodeDef& to const Node&
|
||||
Status FindKernelDef(const DeviceType& device_type, const NodeDef& node_def,
|
||||
const KernelDef** def, string* kernel_class_name) {
|
||||
const KernelRegistration* reg = nullptr;
|
||||
|
@ -184,8 +184,8 @@ class InferenceContext {
|
||||
}
|
||||
#ifndef NDEBUG
|
||||
for (int i = 0; i < num_outputs(); ++i) {
|
||||
DCHECK(output(i).IsSet())
|
||||
<< i << " for " << node_def_.name() << " of type " << node_def_.op();
|
||||
DCHECK(output(i).IsSet()) << i << " for " << node_def().name()
|
||||
<< " of type " << node_def().op();
|
||||
}
|
||||
#endif // NDEBUG
|
||||
return s;
|
||||
@ -394,6 +394,11 @@ class InferenceContext {
|
||||
// the value.
|
||||
Status MakeDimForScalarInput(int idx, DimensionHandle* out);
|
||||
|
||||
// Returns the NodeDef. The returned reference does not outlive the
|
||||
// InferenceContext, and it should not be used after InferenceContext is
|
||||
// destroyed.
|
||||
const NodeDef& node_def() { return node_def_; }
|
||||
|
||||
// Look up the attr for the NodeDef being evaluated with name attr_name and
|
||||
// set *value to its value. If no attr with attr_name is found in def(), or
|
||||
// the attr does not have a matching type, a non-ok status will be returned.
|
||||
|
@ -88,7 +88,7 @@ Status BuildControlFlowInfo(Graph* g, std::vector<ControlFlowInfo>* info) {
|
||||
out_info->frame = out;
|
||||
out_info->parent_frame = frame;
|
||||
TF_RETURN_IF_ERROR(
|
||||
GetNodeAttr(out->attrs(), "frame_name", &out_info->frame_name));
|
||||
GetNodeAttr(out->def(), "frame_name", &out_info->frame_name));
|
||||
if (out_info->frame_name.empty()) {
|
||||
return errors::InvalidArgument("The Enter node ", out->name(),
|
||||
" must have a frame name.");
|
||||
|
@ -78,7 +78,7 @@ string Node::DebugString() const {
|
||||
} else {
|
||||
strings::StrAppend(&ret, " op device:");
|
||||
strings::StrAppend(&ret, "{", assigned_device_name_, "}");
|
||||
strings::StrAppend(&ret, " def:{", SummarizeNode(*this), "}}");
|
||||
strings::StrAppend(&ret, " def:{", SummarizeNodeDef(def()), "}}");
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
@ -474,7 +474,7 @@ void Graph::ToGraphDefSubRange(GraphDef* graph_def, int from_node_id) const {
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
const Edge* edge = inputs[i];
|
||||
if (edge == nullptr) {
|
||||
node_def->add_input(node->requested_inputs()[i]);
|
||||
node_def->add_input(node->def().input(i));
|
||||
} else {
|
||||
const Node* src = edge->src();
|
||||
if (!src->IsOp()) continue;
|
||||
|
@ -71,7 +71,6 @@ class Node {
|
||||
int cost_id() const { return cost_id_; }
|
||||
const string& name() const { return props_->node_def_.name(); }
|
||||
const string& type_string() const { return props_->node_def_.op(); }
|
||||
|
||||
// def() provides the NodeDef the user supplied, but the specifics
|
||||
// of this Node may have changed due to placement, optimization, etc.
|
||||
// In particular:
|
||||
@ -81,7 +80,6 @@ class Node {
|
||||
// * def().device() is the "user's requested device" and may not match
|
||||
// the actual assigned device, see assigned_device_name() below;
|
||||
// * def().attr() is authoritative.
|
||||
// TODO(irving): Replace with NodeInfo.
|
||||
const NodeDef& def() const { return props_->node_def_; }
|
||||
const OpDef& op_def() const { return *props_->op_def_; }
|
||||
|
||||
@ -94,10 +92,6 @@ class Node {
|
||||
DataType output_type(int32 o) const { return props_->output_types_[o]; }
|
||||
const DataTypeVector& output_types() const { return props_->output_types_; }
|
||||
|
||||
// The device requested by the user. For the actual assigned device,
|
||||
// use assigned_device_name() below.
|
||||
const string& requested_device() const { return def().device(); }
|
||||
|
||||
// This gives the device the runtime has assigned this node to. If
|
||||
// you want the device the user requested, use def().device() instead.
|
||||
// TODO(josh11b): Validate that the assigned_device, if not empty:
|
||||
@ -109,14 +103,6 @@ class Node {
|
||||
assigned_device_name_ = device_name;
|
||||
}
|
||||
|
||||
// Read only access to attributes
|
||||
AttrSlice attrs() const { return AttrSlice(def()); }
|
||||
|
||||
// Inputs requested by the NodeDef. For the actual inputs, use in_edges.
|
||||
const protobuf::RepeatedPtrField<string>& requested_inputs() const {
|
||||
return def().input();
|
||||
}
|
||||
|
||||
// Get the neighboring nodes via edges either in or out of this node.
|
||||
gtl::iterator_range<NeighborIter> in_nodes() const;
|
||||
gtl::iterator_range<NeighborIter> out_nodes() const;
|
||||
|
@ -424,7 +424,7 @@ Status GraphConstructor::ValidateShape(Node* node) {
|
||||
// For nodes with the _output_shapes atttribute, override the shape.
|
||||
std::vector<TensorShapeProto> shape_attrs;
|
||||
const char* kAttrName = "_output_shapes";
|
||||
if (!GetNodeAttr(node->attrs(), kAttrName, &shape_attrs).ok()) {
|
||||
if (!GetNodeAttr(node->def(), kAttrName, &shape_attrs).ok()) {
|
||||
// No _output_shapes attribute, the AddNode call above was sufficient.
|
||||
return Status::OK();
|
||||
}
|
||||
@ -458,7 +458,7 @@ Status GraphConstructor::ValidateShape(Node* node) {
|
||||
// functions that are not critical to correct execution but
|
||||
// would cause graphs to fail if imported after correcting.
|
||||
//
|
||||
const string& op = node->type_string();
|
||||
const string& op = node->def().op();
|
||||
const std::vector<string> whitelist = {
|
||||
// To be removed after 2017/03/08.
|
||||
"RandomShuffleQueue", "PaddingFIFOQueue", "FIFOQueue",
|
||||
|
@ -146,7 +146,7 @@ class GraphConstructorTest : public ::testing::Test {
|
||||
return "";
|
||||
}
|
||||
std::vector<string> value;
|
||||
Status s = GetNodeAttr(n->attrs(), kColocationAttrName, &value);
|
||||
Status s = GetNodeAttr(n->def(), kColocationAttrName, &value);
|
||||
if (!s.ok()) {
|
||||
return "";
|
||||
}
|
||||
@ -997,7 +997,7 @@ TEST_F(GraphConstructorTest, ImportGraphDef_DefaultAttrs) {
|
||||
}
|
||||
ASSERT_TRUE(a != nullptr);
|
||||
int value = 0;
|
||||
s = GetNodeAttr(a->attrs(), "default_int", &value);
|
||||
s = GetNodeAttr(a->def(), "default_int", &value);
|
||||
ASSERT_EQ(Status::OK(), s) << s << " -- " << a->def().DebugString();
|
||||
EXPECT_EQ(31415, value);
|
||||
}
|
||||
@ -1201,9 +1201,9 @@ TEST_F(GraphConstructorTest, ImportGraphDef_InputMap) {
|
||||
|
||||
// Check that t1's NodeDef is consistent with graph
|
||||
Node* t1 = FindNode("t1");
|
||||
ASSERT_EQ(t1->requested_inputs().size(), 2);
|
||||
ASSERT_EQ(t1->requested_inputs()[0], "input:1");
|
||||
ASSERT_EQ(t1->requested_inputs()[1], "input:0");
|
||||
ASSERT_EQ(t1->def().input_size(), 2);
|
||||
ASSERT_EQ(t1->def().input(0), "input:1");
|
||||
ASSERT_EQ(t1->def().input(1), "input:0");
|
||||
}
|
||||
|
||||
TEST_F(GraphConstructorTest, ImportGraphDef_InputMapWithPrefix) {
|
||||
@ -1254,19 +1254,19 @@ TEST_F(GraphConstructorTest, ImportGraphDef_InputMapWithPrefix) {
|
||||
|
||||
// Check that NodeDefs are consistent with graph
|
||||
Node* t1 = FindNode("import/t1");
|
||||
ASSERT_EQ(t1->requested_inputs().size(), 2);
|
||||
EXPECT_EQ(t1->requested_inputs()[0], "input:0");
|
||||
EXPECT_EQ(t1->requested_inputs()[1], "input:0");
|
||||
ASSERT_EQ(t1->def().input_size(), 2);
|
||||
EXPECT_EQ(t1->def().input(0), "input:0");
|
||||
EXPECT_EQ(t1->def().input(1), "input:0");
|
||||
|
||||
Node* t2 = FindNode("import/t2");
|
||||
ASSERT_EQ(t2->requested_inputs().size(), 2);
|
||||
EXPECT_EQ(t2->requested_inputs()[0], "import/t1:0");
|
||||
EXPECT_EQ(t2->requested_inputs()[1], "import/t1:0");
|
||||
ASSERT_EQ(t2->def().input_size(), 2);
|
||||
EXPECT_EQ(t2->def().input(0), "import/t1:0");
|
||||
EXPECT_EQ(t2->def().input(1), "import/t1:0");
|
||||
|
||||
Node* t3 = FindNode("import/t3");
|
||||
ASSERT_EQ(t3->requested_inputs().size(), 2);
|
||||
EXPECT_EQ(t3->requested_inputs()[0], "import/unmapped_input:0");
|
||||
EXPECT_EQ(t3->requested_inputs()[1], "import/unmapped_input:1");
|
||||
ASSERT_EQ(t3->def().input_size(), 2);
|
||||
EXPECT_EQ(t3->def().input(0), "import/unmapped_input:0");
|
||||
EXPECT_EQ(t3->def().input(1), "import/unmapped_input:1");
|
||||
}
|
||||
|
||||
TEST_F(GraphConstructorTest, ImportGraphDef_InputMapWithControlEdges) {
|
||||
@ -1795,24 +1795,24 @@ TEST_F(GraphConstructorTest, ImportGraphDef_ControlDeps) {
|
||||
|
||||
// Test that node defs are consistent with graph
|
||||
Node* w1 = FindNode("import/W1");
|
||||
ASSERT_EQ(w1->requested_inputs().size(), 2);
|
||||
EXPECT_EQ(w1->requested_inputs()[0], "^W1");
|
||||
EXPECT_EQ(w1->requested_inputs()[1], "^W2");
|
||||
ASSERT_EQ(w1->def().input_size(), 2);
|
||||
EXPECT_EQ(w1->def().input(0), "^W1");
|
||||
EXPECT_EQ(w1->def().input(1), "^W2");
|
||||
|
||||
Node* input = FindNode("import/input");
|
||||
ASSERT_EQ(input->requested_inputs().size(), 2);
|
||||
EXPECT_EQ(input->requested_inputs()[0], "^W1");
|
||||
EXPECT_EQ(input->requested_inputs()[1], "^W2");
|
||||
ASSERT_EQ(input->def().input_size(), 2);
|
||||
EXPECT_EQ(input->def().input(0), "^W1");
|
||||
EXPECT_EQ(input->def().input(1), "^W2");
|
||||
|
||||
Node* input2 = FindNode("import/input2");
|
||||
ASSERT_EQ(input2->requested_inputs().size(), 2);
|
||||
EXPECT_EQ(input2->requested_inputs()[0], "^W1");
|
||||
EXPECT_EQ(input2->requested_inputs()[1], "^W2");
|
||||
ASSERT_EQ(input2->def().input_size(), 2);
|
||||
EXPECT_EQ(input2->def().input(0), "^W1");
|
||||
EXPECT_EQ(input2->def().input(1), "^W2");
|
||||
|
||||
Node* t1 = FindNode("import/t1");
|
||||
ASSERT_EQ(t1->requested_inputs().size(), 2);
|
||||
EXPECT_EQ(t1->requested_inputs()[0], "import/input:0");
|
||||
EXPECT_EQ(t1->requested_inputs()[1], "import/input:1");
|
||||
ASSERT_EQ(t1->def().input_size(), 2);
|
||||
EXPECT_EQ(t1->def().input(0), "import/input:0");
|
||||
EXPECT_EQ(t1->def().input(1), "import/input:1");
|
||||
}
|
||||
|
||||
TEST_F(GraphConstructorTest, ImportGraphDef_ControlDepsWithCycle) {
|
||||
@ -1856,15 +1856,15 @@ TEST_F(GraphConstructorTest, ImportGraphDef_ControlDepsWithCycle) {
|
||||
|
||||
// Test that node defs are consistent with graph
|
||||
Node* merge = FindNode("merge");
|
||||
ASSERT_EQ(merge->requested_inputs().size(), 3);
|
||||
EXPECT_EQ(merge->requested_inputs()[0], "input:0");
|
||||
EXPECT_EQ(merge->requested_inputs()[1], "t1:0");
|
||||
EXPECT_EQ(merge->requested_inputs()[2], "^W1");
|
||||
ASSERT_EQ(merge->def().input_size(), 3);
|
||||
EXPECT_EQ(merge->def().input(0), "input:0");
|
||||
EXPECT_EQ(merge->def().input(1), "t1:0");
|
||||
EXPECT_EQ(merge->def().input(2), "^W1");
|
||||
|
||||
Node* t1 = FindNode("t1");
|
||||
ASSERT_EQ(t1->requested_inputs().size(), 2);
|
||||
EXPECT_EQ(t1->requested_inputs()[0], "merge:0");
|
||||
EXPECT_EQ(t1->requested_inputs()[1], "merge:0");
|
||||
ASSERT_EQ(t1->def().input_size(), 2);
|
||||
EXPECT_EQ(t1->def().input(0), "merge:0");
|
||||
EXPECT_EQ(t1->def().input(1), "merge:0");
|
||||
}
|
||||
|
||||
TEST_F(GraphConstructorTest, ImportGraphDef_ControlDepsErrors) {
|
||||
|
@ -356,7 +356,7 @@ string ControlLoopName(const string& name) {
|
||||
}
|
||||
|
||||
bool IsControlLoop(const Node* node) {
|
||||
const string& name = node->name();
|
||||
const string& name = node->def().name();
|
||||
return StringPiece(name).starts_with("_cloop");
|
||||
}
|
||||
|
||||
@ -468,7 +468,7 @@ Status AddControlLoop(const PartitionOptions& opts, Graph* g, const Node* src,
|
||||
const string& device_name = edge->dst()->assigned_device_name();
|
||||
const string& frame_name = src_info.frame_name;
|
||||
int parallel_iterations;
|
||||
status = GetNodeAttr(src_info.frame->attrs(), "parallel_iterations",
|
||||
status = GetNodeAttr(src_info.frame->def(), "parallel_iterations",
|
||||
¶llel_iterations);
|
||||
if (!status.ok()) return status;
|
||||
|
||||
@ -903,11 +903,11 @@ Status Partition(const PartitionOptions& opts, Graph* g,
|
||||
send_start_time = opts.start_times[src->id()].value();
|
||||
recv_start_time = opts.start_times[dst->id()].value();
|
||||
} else {
|
||||
status = GetNodeAttr(src->attrs(), "_start_time", &send_start_time);
|
||||
status = GetNodeAttr(src->def(), "_start_time", &send_start_time);
|
||||
if (!status.ok()) {
|
||||
return status;
|
||||
}
|
||||
status = GetNodeAttr(dst->attrs(), "_start_time", &recv_start_time);
|
||||
status = GetNodeAttr(dst->def(), "_start_time", &recv_start_time);
|
||||
if (!status.ok()) {
|
||||
return status;
|
||||
}
|
||||
|
@ -318,21 +318,21 @@ TEST_F(GraphTest, AddAttr) {
|
||||
n1->AddAttr("_a", "new_attr");
|
||||
|
||||
string attr;
|
||||
EXPECT_EQ(Status::OK(), GetNodeAttr(n1->attrs(), "_a", &attr));
|
||||
EXPECT_EQ(Status::OK(), GetNodeAttr(n1->def(), "_a", &attr));
|
||||
EXPECT_EQ("new_attr", attr);
|
||||
|
||||
Node* n2 = graph_.CopyNode(n1);
|
||||
|
||||
n1->AddAttr("_b", "new_attr_2");
|
||||
|
||||
EXPECT_EQ(Status::OK(), GetNodeAttr(n1->attrs(), "_a", &attr));
|
||||
EXPECT_EQ(Status::OK(), GetNodeAttr(n1->def(), "_a", &attr));
|
||||
EXPECT_EQ("new_attr", attr);
|
||||
EXPECT_EQ(Status::OK(), GetNodeAttr(n1->attrs(), "_b", &attr));
|
||||
EXPECT_EQ(Status::OK(), GetNodeAttr(n1->def(), "_b", &attr));
|
||||
EXPECT_EQ("new_attr_2", attr);
|
||||
|
||||
EXPECT_EQ(Status::OK(), GetNodeAttr(n2->attrs(), "_a", &attr));
|
||||
EXPECT_EQ(Status::OK(), GetNodeAttr(n2->def(), "_a", &attr));
|
||||
EXPECT_EQ("new_attr", attr);
|
||||
EXPECT_NE(Status::OK(), GetNodeAttr(n2->attrs(), "_b", &attr));
|
||||
EXPECT_NE(Status::OK(), GetNodeAttr(n2->def(), "_b", &attr));
|
||||
}
|
||||
|
||||
// Convert edge iteration results into a sorted string.
|
||||
|
@ -56,9 +56,11 @@ class OptimizerCSE {
|
||||
bool Optimize(const std::function<bool(const Node*)>& consider_fn);
|
||||
|
||||
private:
|
||||
struct Scratch;
|
||||
|
||||
static size_t NodeHash(const Node* n);
|
||||
static bool Equivalent(const Node* a, const Node* b,
|
||||
AttrSlice::Scratch* scratch);
|
||||
static bool Equivalent(const Node* a, const Node* b, Scratch* s);
|
||||
static bool EqualAttrs(const Node* a, const Node* b, Scratch* s);
|
||||
|
||||
Graph* g_;
|
||||
};
|
||||
@ -108,7 +110,7 @@ size_t OptimizerCSE::NodeHash(const Node* n) {
|
||||
// Hash the attrs. For example, this makes sure different constants
|
||||
// end up in different hash buckets.
|
||||
string tmp;
|
||||
for (const auto& attr : n->attrs()) {
|
||||
for (const auto& attr : n->def().attr()) {
|
||||
tmp = attr.first;
|
||||
attr.second.AppendToString(&tmp);
|
||||
// Add hashes of attrs, so the order of attrs doesn't matter.
|
||||
@ -120,6 +122,28 @@ size_t OptimizerCSE::NodeHash(const Node* n) {
|
||||
return h;
|
||||
}
|
||||
|
||||
struct OptimizerCSE::Scratch {
|
||||
// For EqualAttrs():
|
||||
string a;
|
||||
string b;
|
||||
};
|
||||
|
||||
bool OptimizerCSE::EqualAttrs(const Node* a, const Node* b, Scratch* scratch) {
|
||||
if (a->def().attr_size() != b->def().attr_size()) return false;
|
||||
|
||||
for (const auto& attr : b->def().attr()) {
|
||||
auto iter = a->def().attr().find(attr.first);
|
||||
if (iter == a->def().attr().end()) return false;
|
||||
// Note: it should be safe to compare proto serializations of the attr
|
||||
// values since at most one field should be set in each (indeed, it
|
||||
// should be the same field).
|
||||
iter->second.SerializeToString(&scratch->a);
|
||||
attr.second.SerializeToString(&scratch->b);
|
||||
if (scratch->a != scratch->b) return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool HasRefInput(const Node* n) {
|
||||
for (auto dt : n->input_types()) {
|
||||
if (IsRefType(dt)) return true;
|
||||
@ -127,8 +151,7 @@ static bool HasRefInput(const Node* n) {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool OptimizerCSE::Equivalent(const Node* a, const Node* b,
|
||||
AttrSlice::Scratch* scratch) {
|
||||
bool OptimizerCSE::Equivalent(const Node* a, const Node* b, Scratch* scratch) {
|
||||
// Different op names are different
|
||||
if (a->type_string() != b->type_string()) return false;
|
||||
|
||||
@ -141,7 +164,7 @@ bool OptimizerCSE::Equivalent(const Node* a, const Node* b,
|
||||
|
||||
// Compare attrs. Note that equal attrs implies equal input and
|
||||
// output types.
|
||||
if (!a->attrs().EqualAttrs(b->attrs(), scratch)) return false;
|
||||
if (!EqualAttrs(a, b, scratch)) return false;
|
||||
|
||||
// Compare input sources
|
||||
if (a->num_inputs() != b->num_inputs()) return false;
|
||||
@ -183,7 +206,7 @@ bool OptimizerCSE::Optimize(
|
||||
// Scratch space for Equivalent calls. Allocated here and passed in to
|
||||
// Equivalent to avoid allocation inside the loop below.
|
||||
bool changed = false;
|
||||
AttrSlice::Scratch scratch;
|
||||
Scratch scratch;
|
||||
for (Node* n : order) {
|
||||
if (!n->IsOp()) continue;
|
||||
|
||||
|
@ -192,9 +192,9 @@ Status ConnectVariablesToSaveOp(Graph* graph, Node* save_op,
|
||||
Tensor tensor_names;
|
||||
Tensor shape_and_slices;
|
||||
TF_RETURN_IF_ERROR(
|
||||
GetNodeAttr(tensor_names_op->attrs(), "value", &tensor_names));
|
||||
TF_RETURN_IF_ERROR(
|
||||
GetNodeAttr(shape_and_slices_op->attrs(), "value", &shape_and_slices));
|
||||
GetNodeAttr(AttrSlice(tensor_names_op->def()), "value", &tensor_names));
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(AttrSlice(shape_and_slices_op->def()), "value",
|
||||
&shape_and_slices));
|
||||
|
||||
int tn_size = tensor_names.NumElements();
|
||||
int var_size = added_variables.size();
|
||||
|
@ -112,15 +112,17 @@ TEST_F(QuantizeTrainingTest, SignedInput) {
|
||||
TF_ASSERT_OK(
|
||||
FindNode(g, strings::StrCat(identity->name(), "/QuantizeAndDequantizeV2"),
|
||||
&identity_q_node));
|
||||
NodeDef identity_q = identity_q_node->def();
|
||||
ASSERT_EQ("true",
|
||||
SummarizeAttrValue(*identity_q_node->attrs().Find("signed_input")));
|
||||
SummarizeAttrValue(identity_q.attr().find("signed_input")->second));
|
||||
// Quantize_and_dequantize node for relu should have signed_input==false.
|
||||
Node* relu_q_node;
|
||||
TF_ASSERT_OK(
|
||||
FindNode(g, strings::StrCat(relu->name(), "/QuantizeAndDequantizeV2"),
|
||||
&relu_q_node));
|
||||
NodeDef relu_q = relu_q_node->def();
|
||||
ASSERT_EQ("false",
|
||||
SummarizeAttrValue(*relu_q_node->attrs().Find("signed_input")));
|
||||
SummarizeAttrValue(relu_q.attr().find("signed_input")->second));
|
||||
}
|
||||
|
||||
TEST_F(QuantizeTrainingTest, RangeGivenTrue) {
|
||||
@ -163,15 +165,17 @@ TEST_F(QuantizeTrainingTest, RangeGivenTrue) {
|
||||
TF_ASSERT_OK(
|
||||
FindNode(g, strings::StrCat(relu6->name(), "/QuantizeAndDequantizeV2"),
|
||||
&relu6_q_node));
|
||||
NodeDef identity_q = relu6_q_node->def();
|
||||
ASSERT_EQ("true",
|
||||
SummarizeAttrValue(*relu6_q_node->attrs().Find("range_given")));
|
||||
SummarizeAttrValue(identity_q.attr().find("range_given")->second));
|
||||
// Quantize_and_dequantize node for relu should have range_given==true.
|
||||
Node* relu_q_node;
|
||||
TF_ASSERT_OK(
|
||||
FindNode(g, strings::StrCat(relu->name(), "/QuantizeAndDequantizeV2"),
|
||||
&relu_q_node));
|
||||
NodeDef relu_q = relu_q_node->def();
|
||||
ASSERT_EQ("true",
|
||||
SummarizeAttrValue(*relu_q_node->attrs().Find("range_given")));
|
||||
SummarizeAttrValue(relu_q.attr().find("range_given")->second));
|
||||
}
|
||||
|
||||
TEST_F(QuantizeTrainingTest, WithBackwardNodes_QuantizeAndDequantize) {
|
||||
|
@ -106,7 +106,7 @@ static Status FeedInputs(Graph* g, const DeviceAttributes& device_info,
|
||||
// Copy the _output_shapes from the original node to the feed node,
|
||||
// if any.
|
||||
std::vector<PartialTensorShape> output_shapes;
|
||||
if (GetNodeAttr(n->attrs(), "_output_shapes", &output_shapes).ok()) {
|
||||
if (GetNodeAttr(n->def(), "_output_shapes", &output_shapes).ok()) {
|
||||
if (n->num_outputs() != output_shapes.size()) {
|
||||
return errors::InvalidArgument(
|
||||
"FeedInputs: ", t,
|
||||
@ -129,8 +129,8 @@ static Status FeedInputs(Graph* g, const DeviceAttributes& device_info,
|
||||
if (e->src_output() == id.second) {
|
||||
to_remove.emplace_back(e);
|
||||
} else if (e->src_output() == Graph::kControlSlot &&
|
||||
(n->type_string() == "Placeholder" ||
|
||||
n->type_string() == "PlaceholderV2")) {
|
||||
(n->def().op() == "Placeholder" ||
|
||||
n->def().op() == "PlaceholderV2")) {
|
||||
// When feeding a Placeholder node, any outgoing control edges
|
||||
// will be replaced with a control edge from the replacement
|
||||
// recv_node.
|
||||
|
@ -81,7 +81,7 @@ class SubgraphTest : public ::testing::Test {
|
||||
for (const string& s : expected_nodes) {
|
||||
Node* n = FindNode(s);
|
||||
EXPECT_TRUE(n != nullptr) << s;
|
||||
if (n->type_string() == "_Send" || n->type_string() == "_Recv") {
|
||||
if (n->def().op() == "_Send" || n->def().op() == "_Recv") {
|
||||
EXPECT_EQ(device_info_.name(), n->assigned_device_name()) << s;
|
||||
}
|
||||
}
|
||||
@ -367,7 +367,7 @@ TEST_F(SubgraphTest, FedOutputsPreservesOutputShapes) {
|
||||
for (Node* node : graph()->nodes()) {
|
||||
if (node->name() == "_recv_input_1") {
|
||||
std::vector<PartialTensorShape> shapes;
|
||||
TF_ASSERT_OK(GetNodeAttr(node->attrs(), "_output_shapes", &shapes));
|
||||
TF_ASSERT_OK(GetNodeAttr(node->def(), "_output_shapes", &shapes));
|
||||
ASSERT_EQ(1, shapes.size());
|
||||
EXPECT_TRUE(PartialTensorShape({23}).IsIdenticalTo(shapes[0]));
|
||||
break;
|
||||
|
@ -48,6 +48,12 @@ class Cluster {
|
||||
// of the requested resources are available.
|
||||
virtual Status Provision() = 0;
|
||||
|
||||
// Attempts to shutdown the cluster.
|
||||
// Returns OK iff there are no pending calls to the Run() method and all the
|
||||
// resources used by the cluster could be released. Returns an error
|
||||
// otherwise.
|
||||
virtual Status Shutdown() { return Status::OK(); }
|
||||
|
||||
// Whether soft placement is allowed. If allow_soft_placement is true,
|
||||
// an op will be placed on CPU if there's no GPU implementation for the OP
|
||||
// or if no GPU devices are known or registered or if we need to co-locate
|
||||
@ -58,7 +64,8 @@ class Cluster {
|
||||
// before Provision().
|
||||
void SetNumWarmupSteps(int num_steps);
|
||||
|
||||
// Disable the collection of detailed statistics.
|
||||
// Disable the collection of detailed statistics. Must be called
|
||||
// before Provision().
|
||||
void DisableDetailedStats(bool disable);
|
||||
|
||||
// Return the list of TensorFlow devices that are available to execute a
|
||||
|
@ -25,6 +25,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/notification.h"
|
||||
#include "tensorflow/core/public/session.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -91,6 +92,31 @@ Status SingleMachine::Initialize(const GrapplerItem& item) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status SingleMachine::Shutdown() {
|
||||
TF_RETURN_IF_ERROR(CloseSession(true /*use_timeout*/));
|
||||
|
||||
// Delete the threadpool: this ensures that all the pending closures complete
|
||||
// before we return. Note that if that if TF deadlocked on us, the closures
|
||||
// will never complete, and the call to thread_pool_.reset() will never
|
||||
// return: therefore we need to delete the threadpool with the background
|
||||
// thread. That thread itself will also never complete, so the user should
|
||||
// abort the process to avoid leaking too many resources.
|
||||
auto n = std::make_shared<Notification>();
|
||||
Env::Default()->SchedClosure([this, n]() {
|
||||
thread_pool_.reset();
|
||||
n->Notify();
|
||||
});
|
||||
int64 timeout_us = 1000000ll * timeout_s_;
|
||||
const bool notified = WaitForNotificationWithTimeout(n.get(), timeout_us);
|
||||
if (!notified) {
|
||||
// Let the caller know that we can't shutdown the session properly since
|
||||
// there are calls to Session::Run() still running.
|
||||
return errors::Unavailable("The session is still running graphs after ",
|
||||
timeout_s_, " seconds");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status SingleMachine::Run(const GraphDef& graph_def,
|
||||
const std::vector<std::pair<string, Tensor>>& feed,
|
||||
const std::vector<string>& fetch,
|
||||
@ -163,10 +189,11 @@ Status SingleMachine::RunWithTimeout(
|
||||
mutex_lock l(close_mu_);
|
||||
CHECK(!closing_);
|
||||
}
|
||||
|
||||
auto status = std::make_shared<Status>();
|
||||
auto local_metadata = std::make_shared<RunMetadata>();
|
||||
const bool executed_in_time = ExecuteWithTimeout(
|
||||
[this, status, local_metadata, &feed, &fetch]() {
|
||||
[this, status, local_metadata, feed, fetch]() {
|
||||
*status = session_->Run(run_options_, feed, {}, fetch, nullptr,
|
||||
local_metadata.get());
|
||||
},
|
||||
@ -230,11 +257,7 @@ Status SingleMachine::ResetSession() {
|
||||
LOG(INFO) << "Cleaning up previous session";
|
||||
|
||||
// Make sure the session is properly closed
|
||||
TF_RETURN_IF_ERROR(CloseSession(true /*use_timeout*/));
|
||||
|
||||
// Flush all the pending closures (if any).
|
||||
thread_pool_.reset(new thread::ThreadPool(
|
||||
Env::Default(), SanitizeThreadSuffix("single_machine"), 2));
|
||||
TF_RETURN_IF_ERROR(Shutdown());
|
||||
|
||||
// We need to Reset the session to ensure that all the variables are
|
||||
// deleted. But first we need to delete the session since Reset()
|
||||
@ -245,6 +268,10 @@ Status SingleMachine::ResetSession() {
|
||||
|
||||
LOG(INFO) << "Starting new session";
|
||||
|
||||
// Create a new threadpool
|
||||
thread_pool_.reset(new thread::ThreadPool(
|
||||
Env::Default(), SanitizeThreadSuffix("single_machine"), 2));
|
||||
|
||||
session_.reset(NewSession(options_));
|
||||
CHECK(session_ != nullptr);
|
||||
|
||||
|
@ -33,6 +33,8 @@ class SingleMachine : public Cluster {
|
||||
~SingleMachine() override;
|
||||
|
||||
Status Provision() override;
|
||||
Status Shutdown() override;
|
||||
|
||||
Status Initialize(const GrapplerItem& item) override;
|
||||
Status Run(const GraphDef& item,
|
||||
const std::vector<std::pair<string, Tensor>>& feed,
|
||||
|
@ -122,6 +122,7 @@ cc_library(
|
||||
"//tensorflow/core:core_cpu_base",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_proto_parsing",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
] + if_cuda([
|
||||
"//tensorflow/core:cuda",
|
||||
@ -226,7 +227,6 @@ cc_library(
|
||||
":utils",
|
||||
":virtual_placer",
|
||||
":virtual_scheduler",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:core_cpu_base",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
|
@ -26,7 +26,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/grappler/costs/virtual_placer.h"
|
||||
#include "tensorflow/core/grappler/costs/virtual_scheduler.h"
|
||||
#include "tensorflow/core/grappler/grappler_item.h"
|
||||
#include "tensorflow/core/public/session.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
@ -151,8 +151,8 @@ Status GraphProperties::InferStatically() {
|
||||
|
||||
if (!node->assigned_device_name().empty()) {
|
||||
device_names_[node->name()] = node->assigned_device_name();
|
||||
} else if (!node->requested_device().empty()) {
|
||||
device_names_[node->name()] = node->requested_device();
|
||||
} else if (!node->def().device().empty()) {
|
||||
device_names_[node->name()] = node->def().device();
|
||||
} else {
|
||||
device_names_[node->name()] = "not set";
|
||||
}
|
||||
|
@ -26,6 +26,9 @@ limitations under the License.
|
||||
#include "cuda/include/cudnn.h"
|
||||
#endif
|
||||
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/op_def.pb.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
@ -34,6 +37,8 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/strings/numbers.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/cpu_info.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
#include "tensorflow/core/util/device_name_utils.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -64,6 +69,77 @@ static std::vector<TensorProto> ExtractTensors(const AttrValue& attr_value) {
|
||||
return tensors;
|
||||
}
|
||||
|
||||
static void ExtractExtraProperties(
|
||||
const NodeDef& node,
|
||||
const std::unordered_map<string, const NodeDef*>& name_to_node,
|
||||
std::vector<OpInfo::TensorProperties>* extra_inputs,
|
||||
protobuf::Map<string, AttrValue>* attr_map) {
|
||||
OpRegistry* op_registry = OpRegistry::Global();
|
||||
const OpDef* op_def;
|
||||
auto s = op_registry->LookUpOpDef(node.op(), &op_def);
|
||||
if (!s.ok()) {
|
||||
op_def = nullptr;
|
||||
}
|
||||
|
||||
for (int i = 0; i < node.input_size(); ++i) {
|
||||
const string input_name = node.input(i);
|
||||
CHECK(!input_name.empty());
|
||||
TensorId input_tensor_id = ParseTensorName(input_name);
|
||||
const string input_node_name = input_tensor_id.first.ToString();
|
||||
|
||||
auto iter = name_to_node.find(input_node_name);
|
||||
if (iter == name_to_node.end()) continue;
|
||||
const NodeDef* input_node = iter->second;
|
||||
|
||||
// The value attribute in Const input is useful for cost prediction.
|
||||
if (input_node->op() == "Const") {
|
||||
auto it = input_node->attr().find("value");
|
||||
if (it == input_node->attr().end()) continue;
|
||||
|
||||
const AttrValue& attr_value = it->second;
|
||||
std::vector<TensorProto> tensors = ExtractTensors(attr_value);
|
||||
if (tensors.empty()) continue;
|
||||
|
||||
const TensorProto& t = tensors[0];
|
||||
OpInfo::TensorProperties input;
|
||||
input.set_dtype(t.dtype());
|
||||
*(input.mutable_shape()) = t.tensor_shape();
|
||||
*(input.mutable_value()) = t;
|
||||
extra_inputs->push_back(input);
|
||||
|
||||
// For filename input, the file size can also be useful.
|
||||
if (op_def &&
|
||||
op_def->input_arg(i).name().find("filename") != std::string::npos) {
|
||||
Tensor tensor;
|
||||
CHECK(tensor.FromProto(t));
|
||||
const string filename = tensor.scalar<string>()();
|
||||
|
||||
Env* env = Env::Default();
|
||||
FileStatistics stat;
|
||||
Status s = env->Stat(filename, &stat);
|
||||
if (s.ok()) {
|
||||
AttrValue attr;
|
||||
attr.set_i(stat.length);
|
||||
string attr_key = strings::StrCat("input_", i, "_filesize");
|
||||
(*attr_map)[attr_key] = attr;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// When the input is a handle (e.g. look up table handle), the information
|
||||
// in the op itself is not sufficient to predict the op memory.
|
||||
if (op_def &&
|
||||
op_def->input_arg(i).name().find("handle") != std::string::npos) {
|
||||
string new_key = strings::StrCat("parent_", i, "_op");
|
||||
AttrValue attr;
|
||||
attr.set_s(input_node->op());
|
||||
(*attr_map)[new_key] = attr;
|
||||
// TODO(yuefengz): Only parent node's op name is copied. Copy inputs
|
||||
// and attributes when necessary.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<OpInfo::TensorProperties> FindInputFeatures(
|
||||
const NodeDef& node,
|
||||
const std::unordered_map<string, const CostGraphDef::Node*>& name_to_cost,
|
||||
@ -80,35 +156,6 @@ std::vector<OpInfo::TensorProperties> FindInputFeatures(
|
||||
continue;
|
||||
}
|
||||
|
||||
auto iter = name_to_node.find(input_node_name);
|
||||
if (iter != name_to_node.end()) {
|
||||
const NodeDef* node = iter->second;
|
||||
if (node->op() == "Const") {
|
||||
auto it = node->attr().find("value");
|
||||
if (it == node->attr().end()) {
|
||||
inputs.push_back(UnknownInput());
|
||||
continue;
|
||||
}
|
||||
|
||||
const AttrValue& attr_value = it->second;
|
||||
std::vector<TensorProto> tensors = ExtractTensors(attr_value);
|
||||
|
||||
if (tensors.empty()) {
|
||||
inputs.push_back(UnknownInput());
|
||||
continue;
|
||||
}
|
||||
|
||||
for (const auto& t : tensors) {
|
||||
OpInfo::TensorProperties input;
|
||||
input.set_dtype(t.dtype());
|
||||
*(input.mutable_shape()) = t.tensor_shape();
|
||||
*(input.mutable_value()) = t;
|
||||
inputs.push_back(input);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
auto it = name_to_cost.find(input_node_name);
|
||||
if (it == name_to_cost.end() || output_index < 0) {
|
||||
inputs.push_back(UnknownInput());
|
||||
@ -126,9 +173,9 @@ std::vector<OpInfo::TensorProperties> FindInputFeatures(
|
||||
return inputs;
|
||||
}
|
||||
|
||||
DeviceProperties GetDeviceInfo(const CostGraphDef::Node& node) {
|
||||
DeviceProperties GetDeviceInfo(const string& device_str) {
|
||||
DeviceNameUtils::ParsedName parsed;
|
||||
if (DeviceNameUtils::ParseFullName(node.device(), &parsed)) {
|
||||
if (DeviceNameUtils::ParseFullName(device_str, &parsed)) {
|
||||
if (parsed.type == "GPU") {
|
||||
return GetLocalGPUInfo(parsed.id);
|
||||
} else if (parsed.type == "CPU") {
|
||||
@ -140,5 +187,31 @@ DeviceProperties GetDeviceInfo(const CostGraphDef::Node& node) {
|
||||
return device;
|
||||
}
|
||||
|
||||
DeviceProperties GetDeviceInfo(const CostGraphDef::Node& node) {
|
||||
return GetDeviceInfo(node.device());
|
||||
}
|
||||
|
||||
OpInfo BuildOpInfo(
|
||||
const NodeDef& node, const string& device_str,
|
||||
const std::unordered_map<string, const NodeDef*>& name_to_node,
|
||||
const std::vector<OpInfo::TensorProperties>& inputs) {
|
||||
OpInfo op_info;
|
||||
op_info.set_op(node.op());
|
||||
*op_info.mutable_attr() = node.attr();
|
||||
*op_info.mutable_device() = GetDeviceInfo(device_str);
|
||||
for (auto& input : inputs) {
|
||||
*op_info.add_inputs() = input;
|
||||
}
|
||||
|
||||
std::vector<OpInfo::TensorProperties> extra_inputs;
|
||||
ExtractExtraProperties(node, name_to_node, &extra_inputs,
|
||||
op_info.mutable_attr());
|
||||
for (auto& input : extra_inputs) {
|
||||
*op_info.add_inputs() = input;
|
||||
}
|
||||
|
||||
return op_info;
|
||||
}
|
||||
|
||||
} // end namespace grappler
|
||||
} // end namespace tensorflow
|
||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/framework/cost_graph.pb.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/framework/op_def.pb.h"
|
||||
#include "tensorflow/core/graph/types.h"
|
||||
#include "tensorflow/core/grappler/costs/op_performance_data.pb.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
@ -42,6 +43,15 @@ std::vector<OpInfo::TensorProperties> FindInputFeatures(
|
||||
|
||||
// Returns the DeviceProperties of the device on which 'node' runs.
|
||||
DeviceProperties GetDeviceInfo(const CostGraphDef::Node& node);
|
||||
DeviceProperties GetDeviceInfo(const string& device_str);
|
||||
|
||||
// Builds the OpInfo proto for node, given all nodes in the graph, the node's
|
||||
// device and its input properties which are typically built by shape inference
|
||||
// or calling FindInputFeatures.
|
||||
OpInfo BuildOpInfo(
|
||||
const NodeDef& node, const string& device_str,
|
||||
const std::unordered_map<string, const NodeDef*>& name_to_node,
|
||||
const std::vector<OpInfo::TensorProperties>& inputs);
|
||||
|
||||
} // end namespace grappler
|
||||
} // end namespace tensorflow
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user