Automated rollback of change 137740850

Change: 137747341
A. Unique TensorFlower 2016-10-31 13:11:56 -08:00 committed by TensorFlower Gardener
parent 1dd44f3ecc
commit 41734d78d3
28 changed files with 714 additions and 162 deletions

View File

@ -94,6 +94,38 @@ function(RELATIVE_PROTOBUF_GENERATE_PYTHON ROOT_DIR SRCS)
set(${SRCS} ${${SRCS}} PARENT_SCOPE)
endfunction()
function(RELATIVE_PROTOBUF_GENERATE_CPP SRCS HDRS ROOT_DIR)
if(NOT ARGN)
message(SEND_ERROR "Error: RELATIVE_PROTOBUF_GENERATE_CPP() called without any proto files")
return()
endif()
set(${SRCS})
set(${HDRS})
foreach(FIL ${ARGN})
set(ABS_FIL ${ROOT_DIR}/${FIL})
get_filename_component(FIL_WE ${FIL} NAME_WE)
get_filename_component(FIL_DIR ${ABS_FIL} PATH)
file(RELATIVE_PATH REL_DIR ${ROOT_DIR} ${FIL_DIR})
list(APPEND ${SRCS} "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.pb.cc")
list(APPEND ${HDRS} "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.pb.h")
add_custom_command(
OUTPUT "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.pb.cc"
"${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.pb.h"
COMMAND ${PROTOBUF_PROTOC_EXECUTABLE}
ARGS --cpp_out ${CMAKE_CURRENT_BINARY_DIR} -I ${ROOT_DIR} ${ABS_FIL} -I ${PROTOBUF_INCLUDE_DIRS}
DEPENDS ${ABS_FIL} protobuf
COMMENT "Running C++ protocol buffer compiler on ${FIL}"
VERBATIM )
endforeach()
set_source_files_properties(${${SRCS}} ${${HDRS}} PROPERTIES GENERATED TRUE)
set(${SRCS} ${${SRCS}} PARENT_SCOPE)
set(${HDRS} ${${HDRS}} PARENT_SCOPE)
endfunction()
file(GLOB_RECURSE tf_protos_python_srcs RELATIVE ${tensorflow_source_dir}
"${tensorflow_source_dir}/tensorflow/core/*.proto"
"${tensorflow_source_dir}/tensorflow/python/*.proto"
@ -102,6 +134,12 @@ RELATIVE_PROTOBUF_GENERATE_PYTHON(
${tensorflow_source_dir} PYTHON_PROTO_GENFILES ${tf_protos_python_srcs}
)
RELATIVE_PROTOBUF_GENERATE_CPP(PROTO_SRCS PROTO_HDRS
${tensorflow_source_dir} ${tf_protos_python_srcs}
)
add_library(tf_python_protos_cc ${PROTO_SRCS} ${PROTO_HDRS})
# tf_python_touchup_modules adds empty __init__.py files to all
# directories containing Python code, so that Python will recognize
# them as modules.
@ -201,6 +239,7 @@ function(GENERATE_PYTHON_OP_LIB tf_python_op_lib_name)
)
target_link_libraries(${tf_python_op_lib_name}_gen_python PRIVATE
tf_protos_cc
tf_python_protos_cc
${tensorflow_EXTERNAL_LIBRARIES}
)
@ -312,6 +351,7 @@ target_link_libraries(pywrap_tensorflow
${tf_core_gpu_kernels_lib}
${tensorflow_EXTERNAL_LIBRARIES}
tf_protos_cc
tf_python_protos_cc
${PYTHON_LIBRARIES}
)

View File

@ -379,6 +379,7 @@ tf_gen_op_libs(
"no_op",
"parsing_ops",
"random_ops",
"resource_variable_ops",
"sdca_ops",
"script_ops",
"sendrecv_ops",
@ -542,6 +543,7 @@ cc_library(
"//tensorflow/core/kernels:parsing",
"//tensorflow/core/kernels:random_ops",
"//tensorflow/core/kernels:required",
"//tensorflow/core/kernels:resource_variable_ops",
"//tensorflow/core/kernels:sdca_ops",
"//tensorflow/core/kernels:sparse",
"//tensorflow/core/kernels:state",

View File

@ -42,6 +42,8 @@ Status ShapeRefiner::AddNode(const Node* node) {
// indexed by 'node's input.
std::vector<Node*> input_nodes(node->num_inputs());
std::vector<ShapeHandle> input_shapes(node->num_inputs());
std::vector<DataType> input_handle_dtypes(node->num_inputs());
std::vector<ShapeHandle> input_handle_shapes(node->num_inputs());
for (const Edge* e : node->in_edges()) {
if (e->IsControlEdge()) continue;
@ -57,6 +59,15 @@ Status ShapeRefiner::AddNode(const Node* node) {
DCHECK_GE(e->dst_input(), 0);
input_nodes[e->dst_input()] = input;
input_shapes[e->dst_input()] = c->output(e->src_output());
// Only propagate handle shape and dtype for edges that carry
// resource handles.
if (e->src()->output_type(e->src_output()) == DT_RESOURCE) {
input_handle_dtypes[e->dst_input()] =
c->output_handle_dtype(e->src_output());
input_handle_shapes[e->dst_input()] =
c->output_handle_shape(e->src_output());
}
}
// Get the shape function for this node
@ -76,9 +87,9 @@ Status ShapeRefiner::AddNode(const Node* node) {
std::vector<ShapeHandle> input_tensors_as_shapes;
// Create the inference context for this node with the existing input shapes.
std::unique_ptr<InferenceContext> c(
new InferenceContext(&node->def(), node->op_def(), input_shapes,
input_tensors, input_tensors_as_shapes));
std::unique_ptr<InferenceContext> c(new InferenceContext(
&node->def(), node->op_def(), input_shapes, input_tensors,
input_tensors_as_shapes, input_handle_shapes, input_handle_dtypes));
if (!c->construction_status().ok()) {
return c->construction_status();
}

View File

@ -56,7 +56,7 @@ TEST(CommonShapeFnsTest, NoOutputShapeTest) {
.Input({{"data", 0, DT_FLOAT}})
.Finalize(&def));
InferenceContext c(&def, op_def, {S({}), S({10})}, {}, {});
InferenceContext c(&def, op_def, {S({}), S({10})}, {}, {}, {}, {});
TF_EXPECT_OK(NoOutputs(&c));
EXPECT_EQ(0, c.num_outputs());
}
@ -74,14 +74,14 @@ TEST(CommonShapeFnsTest, ScalarShapeTest) {
NodeDefBuilder("test", "L2Loss").Input("t", 0, DT_FLOAT).Finalize(&def));
{
InferenceContext c(&def, op_def, {S({})}, {}, {});
InferenceContext c(&def, op_def, {S({})}, {}, {}, {}, {});
TF_EXPECT_OK(ScalarShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(0, c.Rank(output));
}
{
InferenceContext c(&def, op_def, {S({1, 23, 4, 4, 2})}, {}, {});
InferenceContext c(&def, op_def, {S({1, 23, 4, 4, 2})}, {}, {}, {}, {});
TF_EXPECT_OK(ScalarShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(0, c.Rank(output));
@ -108,7 +108,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
.Finalize(&def));
{
InferenceContext c(&def, op_def, {S({2, 3}), S({3, 4})}, {}, {});
InferenceContext c(&def, op_def, {S({2, 3}), S({3, 4})}, {}, {}, {}, {});
TF_EXPECT_OK(MatMulShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
@ -117,7 +117,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
{
// Unknown inner dimension for one
InferenceContext c(&def, op_def, {S({2, -1}), S({3, 4})}, {}, {});
InferenceContext c(&def, op_def, {S({2, -1}), S({3, 4})}, {}, {}, {}, {});
TF_EXPECT_OK(MatMulShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
@ -126,7 +126,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
{
// Invalid rank.
InferenceContext c(&def, op_def, {S({2}), S({3, 4})}, {}, {});
InferenceContext c(&def, op_def, {S({2}), S({3, 4})}, {}, {}, {}, {});
auto s = MatMulShape(&c);
EXPECT_FALSE(s.ok());
EXPECT_TRUE(
@ -136,7 +136,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
{
// Unknown outer dimension
InferenceContext c(&def, op_def, {S({2, 3}), S({3, -1})}, {}, {});
InferenceContext c(&def, op_def, {S({2, 3}), S({3, -1})}, {}, {}, {}, {});
TF_EXPECT_OK(MatMulShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
@ -145,7 +145,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
{
// Inner shapes not compatible
InferenceContext c(&def, op_def, {S({2, 5}), S({3, 4})}, {}, {});
InferenceContext c(&def, op_def, {S({2, 5}), S({3, 4})}, {}, {}, {}, {});
auto s = MatMulShape(&c);
EXPECT_FALSE(s.ok());
EXPECT_TRUE(
@ -156,7 +156,8 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
{
// Inner shapes not compatible
InferenceContext c(&def, op_def, {S({2, 5, 3}), S({3, 5, 4})}, {}, {});
InferenceContext c(&def, op_def, {S({2, 5, 3}), S({3, 5, 4})}, {}, {}, {},
{});
auto s = MatMulShape(&c);
EXPECT_FALSE(s.ok());
EXPECT_TRUE(
@ -174,7 +175,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
.Attr("type", DT_FLOAT)
.Finalize(&def));
InferenceContext c(&def, op_def, {S({3, 2}), S({3, 4})}, {}, {});
InferenceContext c(&def, op_def, {S({3, 2}), S({3, 4})}, {}, {}, {}, {});
auto s = MatMulShape(&c);
ShapeHandle output = c.output(0);
EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
@ -191,7 +192,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
.Attr("type", DT_FLOAT)
.Finalize(&def));
InferenceContext c(&def, op_def, {S({2, 3}), S({4, 3})}, {}, {});
InferenceContext c(&def, op_def, {S({2, 3}), S({4, 3})}, {}, {}, {}, {});
auto s = MatMulShape(&c);
ShapeHandle output = c.output(0);
EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
@ -215,7 +216,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
.Finalize(&def));
{
InferenceContext c(&def, op_def, {S({2, 10}), S({10})}, {}, {});
InferenceContext c(&def, op_def, {S({2, 10}), S({10})}, {}, {}, {}, {});
TF_EXPECT_OK(BiasAddShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
@ -224,7 +225,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
{
// Unknown ranks.
InferenceContext c(&def, op_def, {Unknown(), Unknown()}, {}, {});
InferenceContext c(&def, op_def, {Unknown(), Unknown()}, {}, {}, {}, {});
TF_EXPECT_OK(BiasAddShape(&c));
ShapeHandle output = c.output(0);
EXPECT_FALSE(c.RankKnown(output));
@ -232,7 +233,8 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
{
// Rank > 2
InferenceContext c(&def, op_def, {S({4, 3, 4, 2, 15}), S({15})}, {}, {});
InferenceContext c(&def, op_def, {S({4, 3, 4, 2, 15}), S({15})}, {}, {}, {},
{});
TF_EXPECT_OK(BiasAddShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ("[4,3,4,2,15]", c.DebugString(output));
@ -245,7 +247,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
.Input("b", 0, DT_FLOAT)
.Attr("data_format", "NCHW")
.Finalize(&def));
InferenceContext c(&def, op_def, {S({2, 3, 4, 5}), S({3})}, {}, {});
InferenceContext c(&def, op_def, {S({2, 3, 4, 5}), S({3})}, {}, {}, {}, {});
TF_EXPECT_OK(BiasAddShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ("[2,3,4,5]", c.DebugString(output));
@ -258,8 +260,8 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
.Input("b", 0, DT_FLOAT)
.Attr("data_format", "NCHW")
.Finalize(&def));
InferenceContext c(&def, op_def, {S({8, 6, 4, 2, 3, 4, 5}), S({3})}, {},
{});
InferenceContext c(&def, op_def, {S({8, 6, 4, 2, 3, 4, 5}), S({3})}, {}, {},
{}, {});
TF_EXPECT_OK(BiasAddShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ("[8,6,4,2,3,4,5]", c.DebugString(output));
@ -272,7 +274,8 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
.Input("b", 0, DT_FLOAT)
.Attr("data_format", "NCHW")
.Finalize(&def));
InferenceContext c(&def, op_def, {S({10, 11, 12}), S({10})}, {}, {});
InferenceContext c(&def, op_def, {S({10, 11, 12}), S({10})}, {}, {}, {},
{});
TF_EXPECT_OK(BiasAddShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ("[10,11,12]", c.DebugString(output));
@ -280,7 +283,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
{
// Input rank not high enough
InferenceContext c(&def, op_def, {S({3}), S({3})}, {}, {});
InferenceContext c(&def, op_def, {S({3}), S({3})}, {}, {}, {}, {});
EXPECT_FALSE(BiasAddShape(&c).ok());
}
@ -292,7 +295,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
.Attr("data_format", "NCHW")
.Finalize(&def));
// NCHW format
InferenceContext c(&def, op_def, {S({2, 3}), S({3})}, {}, {});
InferenceContext c(&def, op_def, {S({2, 3}), S({3})}, {}, {}, {}, {});
EXPECT_FALSE(BiasAddShape(&c).ok());
}
}
@ -311,7 +314,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
.Finalize(&def));
{
InferenceContext c(&def, op_def, {S({2, 10})}, {}, {});
InferenceContext c(&def, op_def, {S({2, 10})}, {}, {}, {}, {});
TF_EXPECT_OK(BiasAddGradShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(10, c.Value(c.Dim(output, 0)));
@ -319,7 +322,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
{
// Rank > 2
InferenceContext c(&def, op_def, {S({5, 7, 2, 10})}, {}, {});
InferenceContext c(&def, op_def, {S({5, 7, 2, 10})}, {}, {}, {}, {});
TF_EXPECT_OK(BiasAddGradShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(10, c.Value(c.Dim(output, 0)));
@ -331,7 +334,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
.Input("a", 0, DT_FLOAT)
.Attr("data_format", "NCHW")
.Finalize(&def));
InferenceContext c(&def, op_def, {S({2, 3, 4, 5})}, {}, {});
InferenceContext c(&def, op_def, {S({2, 3, 4, 5})}, {}, {}, {}, {});
TF_EXPECT_OK(BiasAddGradShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(3, c.Value(c.Dim(output, 0)));
@ -343,7 +346,8 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
.Input("a", 0, DT_FLOAT)
.Attr("data_format", "NCHW")
.Finalize(&def));
InferenceContext c(&def, op_def, {S({8, 6, 4, 2, 3, 4, 5})}, {}, {});
InferenceContext c(&def, op_def, {S({8, 6, 4, 2, 3, 4, 5})}, {}, {}, {},
{});
TF_EXPECT_OK(BiasAddGradShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(3, c.Value(c.Dim(output, 0)));
@ -355,7 +359,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
.Input("a", 0, DT_FLOAT)
.Attr("data_format", "NCHW")
.Finalize(&def));
InferenceContext c(&def, op_def, {S({10, 11, 12})}, {}, {});
InferenceContext c(&def, op_def, {S({10, 11, 12})}, {}, {}, {}, {});
TF_EXPECT_OK(BiasAddGradShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(10, c.Value(c.Dim(output, 0)));
@ -363,7 +367,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
{
// Input rank not high enough
InferenceContext c(&def, op_def, {S({3})}, {}, {});
InferenceContext c(&def, op_def, {S({3})}, {}, {}, {}, {});
EXPECT_FALSE(BiasAddGradShape(&c).ok());
}
@ -374,7 +378,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
.Attr("data_format", "NCHW")
.Finalize(&def));
// NCHW format
InferenceContext c(&def, op_def, {S({2, 3})}, {}, {});
InferenceContext c(&def, op_def, {S({2, 3})}, {}, {}, {}, {});
EXPECT_FALSE(BiasAddGradShape(&c).ok());
}
}

View File

@ -66,6 +66,14 @@ void AddNodeAttr(StringPiece name, std::initializer_list<T> value,
AttrValueMap::value_type(name.ToString(), attr_value));
}
// Adds an attr to an attr value map.
template <class T>
void AddAttr(StringPiece name, T&& value, AttrValueMap* map) {
AttrValue attr_value;
SetAttrValue(value, &attr_value);
map->insert(AttrValueMap::value_type(name.ToString(), attr_value));
}
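For context, a minimal usage sketch of the new helper (the local map name is illustrative, not part of the change):

// Usage sketch: populate an AttrValueMap with the new helper.
AttrValueMap attrs;
AddAttr("dtype", DT_FLOAT, &attrs);  // stores a SetAttrValue-encoded AttrValue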
class AttrSlice {
public:
AttrSlice(const NodeDef& node_def); // NOLINT(runtime/explicit)

View File

@ -21,6 +21,7 @@ limitations under the License.
#include <typeinfo>
#include <unordered_map>
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"

View File

@ -31,7 +31,9 @@ InferenceContext::InferenceContext(
const NodeDef* node_def, const OpDef& op_def,
const std::vector<TensorShapeProto>& input_shapes,
const std::vector<const Tensor*>& input_tensors,
const std::vector<ShapeHandle>& input_tensors_as_shapes)
const std::vector<ShapeHandle>& input_tensors_as_shapes,
const std::vector<TensorShapeProto>& input_handle_shapes,
const std::vector<DataType>& input_handle_dtypes)
: node_def_(*CHECK_NOTNULL(node_def)) {
PreInputInit(op_def, input_tensors, input_tensors_as_shapes);
if (!construction_status_.ok()) return;
@ -43,19 +45,30 @@ InferenceContext::InferenceContext(
}
inputs_.push_back(shape);
}
PostInputInit();
std::vector<ShapeHandle> handle_shapes;
for (const auto& p : input_handle_shapes) {
ShapeHandle shape;
construction_status_.Update(MakeShapeFromShapeProto(p, &shape));
if (!construction_status_.ok()) {
return;
}
handle_shapes.push_back(shape);
}
PostInputInit(handle_shapes, input_handle_dtypes);
}
InferenceContext::InferenceContext(
const NodeDef* node_def, const OpDef& op_def,
const std::vector<ShapeHandle>& input_shapes,
const std::vector<const Tensor*>& input_tensors,
const std::vector<ShapeHandle>& input_tensors_as_shapes)
const std::vector<ShapeHandle>& input_tensors_as_shapes,
const std::vector<ShapeHandle>& input_handle_shapes,
const std::vector<DataType>& input_handle_dtypes)
: node_def_(*CHECK_NOTNULL(node_def)) {
PreInputInit(op_def, input_tensors, input_tensors_as_shapes);
if (!construction_status_.ok()) return;
inputs_ = input_shapes;
PostInputInit();
PostInputInit(input_handle_shapes, input_handle_dtypes);
}
InferenceContext::~InferenceContext() {
@ -124,15 +137,44 @@ void InferenceContext::PreInputInit(
for (int i = 0; i < num_outputs; ++i) {
outputs_.push_back(nullptr);
}
output_handle_shape_.reserve(num_outputs);
for (int i = 0; i < num_outputs; ++i) {
output_handle_shape_.push_back(UnknownShape());
}
output_handle_dtype_ = std::vector<DataType>(num_outputs, DT_INVALID);
}
void InferenceContext::PostInputInit() {
void InferenceContext::PostInputInit(
const std::vector<ShapeHandle>& input_handle_shapes,
const std::vector<DataType>& input_handle_dtypes) {
int num_inputs_from_node_def = 0;
for (const auto& e : input_name_map_) {
num_inputs_from_node_def =
std::max(num_inputs_from_node_def, e.second.second);
}
// Allow passing empty shapes/dtypes to avoid changing every single test.
if (input_handle_shapes.empty()) {
input_handle_shape_.resize(inputs_.size());
} else {
input_handle_shape_ = input_handle_shapes;
if (input_handle_shape_.size() != inputs_.size()) {
construction_status_ = errors::InvalidArgument(
"Wrong number of handle shapes passed; expected ", inputs_.size(),
" got ", input_handle_shape_.size());
}
}
if (input_handle_dtypes.empty()) {
input_handle_dtype_ = std::vector<DataType>(inputs_.size(), DT_INVALID);
} else {
input_handle_dtype_ = input_handle_dtypes;
if (input_handle_dtype_.size() != inputs_.size()) {
construction_status_ = errors::InvalidArgument(
"Wrong number of handle dtypes passed; expected ", inputs_.size(),
" got ", input_handle_dtype_.size());
}
}
if (inputs_.size() != num_inputs_from_node_def) {
construction_status_ = errors::InvalidArgument(
"Wrong number of inputs passed: ", inputs_.size(), " while ",
@ -737,6 +779,13 @@ Status InferenceContext::AttachContext(const Status& status) {
strings::StrCat(status.error_message(), error_context));
}
ShapeHandle InferenceContext::input_handle_shape(int idx) {
if (!input_handle_shape_[idx].IsSet()) {
input_handle_shape_[idx] = UnknownShape();
}
return input_handle_shape_[idx];
}
// -----------------------------------------------------------------------------
// ShapeManager
// -----------------------------------------------------------------------------

View File

@ -147,7 +147,9 @@ class InferenceContext {
InferenceContext(const NodeDef* node_def, const OpDef& op_def,
const std::vector<ShapeHandle>& input_shapes,
const std::vector<const Tensor*>& input_tensors,
const std::vector<ShapeHandle>& input_tensors_as_shapes);
const std::vector<ShapeHandle>& input_tensors_as_shapes,
const std::vector<ShapeHandle>& input_handle_shapes,
const std::vector<DataType>& input_handle_dtypes);
// <input_tensors> is NULL-padded to be the same size as <input_shapes>.
//
@ -162,7 +164,9 @@ class InferenceContext {
InferenceContext(const NodeDef* node_def, const OpDef& op_def,
const std::vector<TensorShapeProto>& input_shapes,
const std::vector<const Tensor*>& input_tensors,
const std::vector<ShapeHandle>& input_tensors_as_shapes);
const std::vector<ShapeHandle>& input_tensors_as_shapes,
const std::vector<TensorShapeProto>& input_handle_shapes,
const std::vector<DataType>& input_handle_dtypes);
~InferenceContext();
@ -231,12 +235,12 @@ class InferenceContext {
}
return s->dims_[idx];
}
int32 Rank(ShapeHandle s) { return s->rank_; }
bool RankKnown(ShapeHandle s) { return Rank(s) != kUnknownRank; }
inline int64 Value(DimensionOrConstant d) {
int32 Rank(ShapeHandle s) const { return s->rank_; }
bool RankKnown(ShapeHandle s) const { return Rank(s) != kUnknownRank; }
inline int64 Value(DimensionOrConstant d) const {
return d.dim.IsSet() ? d.dim->value_ : d.val;
}
inline bool ValueKnown(DimensionOrConstant d) {
inline bool ValueKnown(DimensionOrConstant d) const {
return Value(d) != kUnknownDim;
}
@ -391,6 +395,30 @@ class InferenceContext {
Status construction_status() const { return construction_status_; }
// Methods to propagate shape and dtype on edges of handles. Handles are
// tensors of dtype DT_RESOURCE, which can be used to access state stored in
// a ResourceManager. When ops (such as variables) consume these handles to
// produce tensors, they might need side information about the shapes and
// dtypes of the tensors accessible via the handle. These methods propagate
// that information. Output handle dtypes and shapes are ignored if the
// output tensor is not of type DT_RESOURCE.
ShapeHandle input_handle_shape(int idx);
DataType input_handle_dtype(int idx) const {
return input_handle_dtype_[idx];
}
void set_output_handle_shape(int idx, ShapeHandle shape) {
output_handle_shape_[idx] = shape;
}
void set_output_handle_dtype(int idx, DataType dtype) {
output_handle_dtype_[idx] = dtype;
}
ShapeHandle output_handle_shape(int idx) const {
return output_handle_shape_[idx];
}
DataType output_handle_dtype(int idx) const {
return output_handle_dtype_[idx];
}
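A minimal sketch of how a shape function can use these accessors, mirroring the Identity change later in this commit (the op and function names here are hypothetical):

// Sketch: forward handle metadata through a hypothetical pass-through op.
Status PassThroughShapeFn(shape_inference::InferenceContext* c) {
  c->set_output(0, c->input(0));                            // tensor shape
  c->set_output_handle_dtype(0, c->input_handle_dtype(0));  // handle dtype
  c->set_output_handle_shape(0, c->input_handle_shape(0));  // handle shape
  return Status::OK();
}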
// Validates the 3 component tensors of a sparse tensor have the proper
// shapes. This mimics SparseTensor.__init__ in python/framework/ops.py.
Status ValidateSparseTensor(ShapeHandle indices_shape,
@ -481,7 +509,8 @@ class InferenceContext {
void PreInputInit(const OpDef& op_def,
const std::vector<const Tensor*>& input_tensors,
const std::vector<ShapeHandle>& input_tensors_as_shapes);
void PostInputInit();
void PostInputInit(const std::vector<ShapeHandle>& input_handle_shapes,
const std::vector<DataType>& input_handle_dtypes);
DimensionHandle GetDimension(const DimensionOrConstant& d);
@ -510,6 +539,11 @@ class InferenceContext {
std::vector<ShapeHandle> input_tensors_as_shapes_;
std::vector<bool> requested_input_tensor_as_partial_shape_;
std::vector<ShapeHandle> input_handle_shape_;
std::vector<DataType> input_handle_dtype_;
std::vector<ShapeHandle> output_handle_shape_;
std::vector<DataType> output_handle_dtype_;
const NodeDef& node_def_;
NameRangeMap input_name_map_;
NameRangeMap output_name_map_;

View File

@ -71,7 +71,8 @@ TEST_F(ShapeInferenceTest, InputOutputByName) {
.Attr("N", 3)
.Input(FakeInput(DT_FLOAT))
.Finalize(&def);
InferenceContext c(&def, op_def, {S({1, 5}), S({2, 5}), S({1, 3})}, {}, {});
InferenceContext c(&def, op_def, {S({1, 5}), S({2, 5}), S({1, 3})}, {}, {},
{}, {});
EXPECT_EQ("5", c.DebugString(c.NumElements(c.input(0))));
EXPECT_EQ("10", c.DebugString(c.NumElements(c.input(1))));
@ -107,7 +108,7 @@ static OpDef MakeOpDef(int num_inputs, int num_outputs) {
TEST_F(ShapeInferenceTest, DimensionOrConstant) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(1, 1), {Unknown()}, {}, {});
InferenceContext c(&def, MakeOpDef(1, 1), {Unknown()}, {}, {}, {}, {});
EXPECT_EQ(InferenceContext::kUnknownDim,
c.Value(InferenceContext::kUnknownDim));
EXPECT_EQ(1, c.Value(1));
@ -122,7 +123,7 @@ TEST_F(ShapeInferenceTest, Run) {
NodeDef def;
def.set_name("foo");
def.set_op("foo_op");
InferenceContext c(&def, MakeOpDef(3, 2), {S({1})}, {}, {});
InferenceContext c(&def, MakeOpDef(3, 2), {S({1})}, {}, {}, {}, {});
{
auto fn = [](InferenceContext* c) {
@ -154,7 +155,7 @@ TEST_F(ShapeInferenceTest, Run) {
TEST_F(ShapeInferenceTest, RankAndDimInspection) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(3, 2), {Unknown(), S({1, -1, 3}), S({})},
{}, {});
{}, {}, {}, {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(2, c.num_outputs());
@ -195,7 +196,8 @@ TEST_F(ShapeInferenceTest, RankAndDimInspection) {
TEST_F(ShapeInferenceTest, NumElements) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(3, 2),
{Unknown(), S({1, -1, 3}), S({5, 4, 3, 2})}, {}, {});
{Unknown(), S({1, -1, 3}), S({5, 4, 3, 2})}, {}, {}, {},
{});
EXPECT_EQ("?", c.DebugString(c.NumElements(c.input(0))));
EXPECT_EQ("?", c.DebugString(c.NumElements(c.input(1))));
@ -208,7 +210,8 @@ TEST_F(ShapeInferenceTest, NumElements) {
TEST_F(ShapeInferenceTest, WithRank) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(2, 2), {Unknown(), S({1, -1, 3})}, {}, {});
InferenceContext c(&def, MakeOpDef(2, 2), {Unknown(), S({1, -1, 3})}, {}, {},
{}, {});
auto in0 = c.input(0);
auto in1 = c.input(1);
@ -246,7 +249,8 @@ TEST_F(ShapeInferenceTest, WithRank) {
TEST_F(ShapeInferenceTest, WithRankAtMost) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(2, 2), {Unknown(), S({1, -1, 3})}, {}, {});
InferenceContext c(&def, MakeOpDef(2, 2), {Unknown(), S({1, -1, 3})}, {}, {},
{}, {});
auto in0 = c.input(0);
auto in1 = c.input(1);
@ -284,7 +288,8 @@ TEST_F(ShapeInferenceTest, WithRankAtMost) {
TEST_F(ShapeInferenceTest, WithRankAtLeast) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(2, 2), {Unknown(), S({1, -1, 3})}, {}, {});
InferenceContext c(&def, MakeOpDef(2, 2), {Unknown(), S({1, -1, 3})}, {}, {},
{}, {});
auto in0 = c.input(0);
auto in1 = c.input(1);
@ -322,7 +327,7 @@ TEST_F(ShapeInferenceTest, WithRankAtLeast) {
TEST_F(ShapeInferenceTest, WithValue) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(1, 2), {S({1, -1})}, {}, {});
InferenceContext c(&def, MakeOpDef(1, 2), {S({1, -1})}, {}, {}, {}, {});
auto d0 = c.Dim(c.input(0), 0);
auto d1 = c.Dim(c.input(0), 1);
@ -363,7 +368,8 @@ TEST_F(ShapeInferenceTest, WithValue) {
TEST_F(ShapeInferenceTest, MergeDim) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(1, 2), {S({2, -1, 2, 1, -1})}, {}, {});
InferenceContext c(&def, MakeOpDef(1, 2), {S({2, -1, 2, 1, -1})}, {}, {}, {},
{});
auto d2 = c.Dim(c.input(0), 0);
auto d_unknown = c.Dim(c.input(0), 1);
@ -412,7 +418,7 @@ TEST_F(ShapeInferenceTest, MergeShape) {
InferenceContext c(&def, MakeOpDef(7, 2),
{Unknown(), S({1, 2}), S({-1, 2}), S({1, -1}), S({1, 3}),
Unknown(), S({1})},
{}, {});
{}, {}, {}, {});
auto s_unknown = c.input(0);
auto s_1_2 = c.input(1);
@ -483,7 +489,7 @@ TEST_F(ShapeInferenceTest, MergePrefix) {
{
Unknown(), S({-1, 2}), S({1, -1, 3}), S({2, 4}),
},
{}, {});
{}, {}, {}, {});
auto s_unknown = c.input(0);
auto s_u_2 = c.input(1);
@ -536,7 +542,7 @@ TEST_F(ShapeInferenceTest, MergePrefix) {
TEST_F(ShapeInferenceTest, Subshape) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(2, 2), {S({1, 2, 3, -1, 5}), Unknown()},
{}, {});
{}, {}, {}, {});
ShapeHandle unknown = c.input(1);
ShapeHandle out;
@ -611,7 +617,7 @@ TEST_F(ShapeInferenceTest, Subshape) {
TEST_F(ShapeInferenceTest, Concatenate) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(3, 2),
{S({1, -1, 3}), S({4, 5}), Unknown()}, {}, {});
{S({1, -1, 3}), S({4, 5}), Unknown()}, {}, {}, {}, {});
auto in0 = c.input(0);
auto in1 = c.input(1);
@ -637,7 +643,8 @@ TEST_F(ShapeInferenceTest, Concatenate) {
TEST_F(ShapeInferenceTest, ReplaceDim) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(2, 0), {S({1, 2, 3}), Unknown()}, {}, {});
InferenceContext c(&def, MakeOpDef(2, 0), {S({1, 2, 3}), Unknown()}, {}, {},
{}, {});
auto in = c.input(0);
auto unknown = c.input(1);
@ -668,7 +675,8 @@ TEST_F(ShapeInferenceTest, ReplaceDim) {
TEST_F(ShapeInferenceTest, MakeShape) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(1, 2), {S({1, 2, 3, -1, 5})}, {}, {});
InferenceContext c(&def, MakeOpDef(1, 2), {S({1, 2, 3, -1, 5})}, {}, {}, {},
{});
std::vector<DimensionHandle> dims;
auto in0 = c.input(0);
@ -693,7 +701,7 @@ TEST_F(ShapeInferenceTest, MakeShape) {
TEST_F(ShapeInferenceTest, UnknownShape) {
NodeDef def;
std::vector<ShapeHandle> empty;
InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {});
InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {}, {}, {});
auto u0 = c.UnknownShape();
auto u1 = c.UnknownShape();
@ -705,7 +713,7 @@ TEST_F(ShapeInferenceTest, UnknownShape) {
TEST_F(ShapeInferenceTest, Scalar) {
NodeDef def;
std::vector<ShapeHandle> empty;
InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {});
InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {}, {}, {});
auto s0 = c.Scalar();
EXPECT_EQ("[]", c.DebugString(s0));
@ -716,7 +724,7 @@ TEST_F(ShapeInferenceTest, Scalar) {
TEST_F(ShapeInferenceTest, Vector) {
NodeDef def;
std::vector<ShapeHandle> empty;
InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {});
InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {}, {}, {});
auto s0 = c.Vector(1);
EXPECT_EQ("[1]", c.DebugString(s0));
@ -732,7 +740,7 @@ TEST_F(ShapeInferenceTest, Vector) {
TEST_F(ShapeInferenceTest, Matrix) {
NodeDef def;
std::vector<ShapeHandle> empty;
InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {});
InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {}, {}, {});
auto s0 = c.Matrix(1, 2);
EXPECT_EQ("[1,2]", c.DebugString(s0));
@ -754,7 +762,7 @@ TEST_F(ShapeInferenceTest, Matrix) {
TEST_F(ShapeInferenceTest, MakeShapeFromShapeTensor) {
auto create = [&](Tensor* t) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(1, 0), {Unknown()}, {t}, {});
InferenceContext c(&def, MakeOpDef(1, 0), {Unknown()}, {t}, {}, {}, {});
ShapeHandle out;
Status s = c.MakeShapeFromShapeTensor(0, &out);
if (s.ok()) {
@ -806,7 +814,8 @@ TEST_F(ShapeInferenceTest, MakeShapeFromShapeTensor) {
// Test when the input shape is wrong.
{
NodeDef def;
InferenceContext c(&def, MakeOpDef(1, 0), {S({1, -1})}, {nullptr}, {});
InferenceContext c(&def, MakeOpDef(1, 0), {S({1, -1})}, {nullptr}, {}, {},
{});
ShapeHandle out;
EXPECT_EQ("Shape must be rank 1 but is rank 2",
c.MakeShapeFromShapeTensor(0, &out).error_message());
@ -816,7 +825,7 @@ TEST_F(ShapeInferenceTest, MakeShapeFromShapeTensor) {
TEST_F(ShapeInferenceTest, MakeShapeFromShapeProto) {
NodeDef def;
std::vector<ShapeHandle> empty;
InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {});
InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {}, {}, {});
TensorShapeProto proto;
// With a set unknown rank.
@ -852,7 +861,7 @@ TEST_F(ShapeInferenceTest, MakeShapeFromShapeProto) {
TEST_F(ShapeInferenceTest, MakeDim) {
NodeDef def;
std::vector<ShapeHandle> empty;
InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {});
InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {}, {}, {});
auto d0 = c.MakeDim(1);
auto d1 = c.MakeDim(1);
@ -866,7 +875,7 @@ TEST_F(ShapeInferenceTest, MakeDim) {
TEST_F(ShapeInferenceTest, UnknownDim) {
NodeDef def;
std::vector<ShapeHandle> empty;
InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {});
InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {}, {}, {});
auto d0 = c.UnknownDim();
auto d1 = c.UnknownDim();
@ -878,7 +887,7 @@ TEST_F(ShapeInferenceTest, UnknownDim) {
TEST_F(ShapeInferenceTest, UnknownShapeOfRank) {
NodeDef def;
std::vector<ShapeHandle> empty;
InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {});
InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {}, {}, {});
auto unknown_shape_of_rank_3 = c.UnknownShapeOfRank(3);
EXPECT_EQ("[?,?,?]", c.DebugString(unknown_shape_of_rank_3));
@ -892,7 +901,7 @@ TEST_F(ShapeInferenceTest, InputTensors) {
const Tensor t2 = tensorflow::test::AsTensor<float>({20, 30});
NodeDef def;
InferenceContext c(&def, MakeOpDef(3, 2), {S({1}), S({2}), S({3})},
{&t1, &t2}, {});
{&t1, &t2}, {}, {}, {});
EXPECT_TRUE(c.input_tensor(0) == &t1);
EXPECT_TRUE(c.input_tensor(1) == &t2);
@ -903,7 +912,8 @@ TEST_F(ShapeInferenceTest, MakeDimForScalarInput) {
Tensor t1 = tensorflow::test::AsScalar<int32>(20);
Tensor t2 = tensorflow::test::AsScalar<int32>(-1);
NodeDef def;
InferenceContext c(&def, MakeOpDef(2, 2), {S({}), S({})}, {&t1, &t2}, {});
InferenceContext c(&def, MakeOpDef(2, 2), {S({}), S({})}, {&t1, &t2}, {}, {},
{});
DimensionHandle d;
EXPECT_TRUE(c.MakeDimForScalarInput(0, &d).ok());
@ -934,7 +944,7 @@ TEST_F(ShapeInferenceTest, GetAttr) {
.ok());
std::vector<ShapeHandle> empty;
InferenceContext c(&def, op_reg_data.op_def, empty, {}, {});
InferenceContext c(&def, op_reg_data.op_def, empty, {}, {}, {}, {});
string value;
EXPECT_TRUE(c.GetAttr("foo", &value).ok());
EXPECT_EQ("bar", value);
@ -942,7 +952,8 @@ TEST_F(ShapeInferenceTest, GetAttr) {
TEST_F(ShapeInferenceTest, Divide) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(1, 2), {S({6, -1, 1, 2, 0})}, {}, {});
InferenceContext c(&def, MakeOpDef(1, 2), {S({6, -1, 1, 2, 0})}, {}, {}, {},
{});
auto s = c.input(0);
auto d_6 = c.Dim(s, 0);
@ -1004,7 +1015,7 @@ TEST_F(ShapeInferenceTest, Divide) {
TEST_F(ShapeInferenceTest, Add) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(1, 2), {S({6, -1, 0})}, {}, {});
InferenceContext c(&def, MakeOpDef(1, 2), {S({6, -1, 0})}, {}, {}, {}, {});
auto s = c.input(0);
auto d_6 = c.Dim(s, 0);
@ -1055,7 +1066,7 @@ TEST_F(ShapeInferenceTest, Add) {
TEST_F(ShapeInferenceTest, Subtract) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(1, 2), {S({6, -1, 0, 5})}, {}, {});
InferenceContext c(&def, MakeOpDef(1, 2), {S({6, -1, 0, 5})}, {}, {}, {}, {});
auto s = c.input(0);
auto d_6 = c.Dim(s, 0);
@ -1104,7 +1115,7 @@ TEST_F(ShapeInferenceTest, Subtract) {
TEST_F(ShapeInferenceTest, Multiply) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(1, 2), {S({6, -1, 0, 1})}, {}, {});
InferenceContext c(&def, MakeOpDef(1, 2), {S({6, -1, 0, 1})}, {}, {}, {}, {});
auto s = c.input(0);
auto d_6 = c.Dim(s, 0);
@ -1157,7 +1168,7 @@ TEST_F(ShapeInferenceTest, Multiply) {
TEST_F(ShapeInferenceTest, FullyDefined) {
NodeDef def;
std::vector<ShapeHandle> empty;
InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {});
InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {}, {}, {});
// No rank or missing dimension information should return false.
EXPECT_FALSE(c.FullyDefined(c.UnknownShape()));
@ -1170,7 +1181,7 @@ TEST_F(ShapeInferenceTest, FullyDefined) {
TEST_F(ShapeInferenceTest, Min) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(1, 2), {S({1, 2, -1, 0})}, {}, {});
InferenceContext c(&def, MakeOpDef(1, 2), {S({1, 2, -1, 0})}, {}, {}, {}, {});
auto s = c.input(0);
auto d_1 = c.Dim(s, 0);
@ -1218,7 +1229,7 @@ TEST_F(ShapeInferenceTest, Min) {
TEST_F(ShapeInferenceTest, Max) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(1, 2), {S({1, 2, -1})}, {}, {});
InferenceContext c(&def, MakeOpDef(1, 2), {S({1, 2, -1})}, {}, {}, {}, {});
auto s = c.input(0);
auto d_1 = c.Dim(s, 0);
@ -1256,7 +1267,7 @@ TEST_F(ShapeInferenceTest, Max) {
TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownShapes) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(3, 1), {Unknown(), Unknown(), Unknown()},
{}, {});
{}, {}, {}, {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
@ -1269,7 +1280,7 @@ TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownShapes) {
TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownDims) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(3, 1), {S({-1, -1}), S({-1}), S({-1})}, {},
{});
{}, {}, {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
@ -1281,8 +1292,8 @@ TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownDims) {
TEST_F(ShapeInferenceTest, ValidateSparseTensor_InvalidIndicesRank) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(3, 1), {S({-1}), S({-1}), S({-1})}, {},
{});
InferenceContext c(&def, MakeOpDef(3, 1), {S({-1}), S({-1}), S({-1})}, {}, {},
{}, {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
@ -1295,8 +1306,8 @@ TEST_F(ShapeInferenceTest, ValidateSparseTensor_InvalidIndicesRank) {
TEST_F(ShapeInferenceTest, ValidateSparseTensor_InvalidNumElements) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({4}), S({3})}, {},
{});
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({4}), S({3})}, {}, {},
{}, {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
@ -1309,8 +1320,8 @@ TEST_F(ShapeInferenceTest, ValidateSparseTensor_InvalidNumElements) {
TEST_F(ShapeInferenceTest, ValidateSparseTensor_InvalidRank) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({5}), S({4})}, {},
{});
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({5}), S({4})}, {}, {},
{}, {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
@ -1324,7 +1335,7 @@ TEST_F(ShapeInferenceTest, ValidateSparseTensor_InvalidRank) {
TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownNumIndexElements) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(3, 1), {S({-1, 3}), S({5}), S({3})}, {},
{});
{}, {}, {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
@ -1337,7 +1348,7 @@ TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownNumIndexElements) {
TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownNumValueElements) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({-1}), S({3})}, {},
{});
{}, {}, {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
@ -1350,7 +1361,7 @@ TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownNumValueElements) {
TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownIndexRank) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, -1}), S({5}), S({3})}, {},
{});
{}, {}, {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
@ -1363,7 +1374,7 @@ TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownIndexRank) {
TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownShapeRank) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({5}), S({-1})}, {},
{});
{}, {}, {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());
@ -1375,8 +1386,8 @@ TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownShapeRank) {
TEST_F(ShapeInferenceTest, ValidateSparseTensor) {
NodeDef def;
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({5}), S({3})}, {},
{});
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({5}), S({3})}, {}, {},
{}, {});
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(1, c.num_outputs());

View File

@ -44,8 +44,7 @@ Status ShapeInferenceTestutil::InferShapes(ShapeInferenceTestOp op,
}
shape_inference::InferenceContext c(&op.node_def, op_reg_data->op_def,
in_shapes, op.input_tensors,
{} /* input_tensors_as_shapes */);
in_shapes, op.input_tensors, {}, {}, {});
TF_RETURN_IF_ERROR(c.construction_status());
if (op_reg_data->shape_inference_fn == nullptr) {
return errors::InvalidArgument(

View File

@ -435,6 +435,7 @@ class Tensor {
friend class VariableOp; // For access to set_shape
friend class AutoReloadVariableOp; // For access to set_shape
friend class TensorTestHelper; // For access to set_shape
friend class CreateVariableOp;
// Creates a tensor with the input datatype, shape and buf.
//

View File

@ -1030,6 +1030,18 @@ tf_kernel_library(
],
)
tf_kernel_library(
name = "resource_variable_ops",
srcs = ["resource_variable_ops.cc"],
deps = [
":variable_ops",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:resource_variable_ops_op_lib",
"//third_party/eigen3",
],
)
tf_kernel_library(
name = "fact_op",
prefix = "fact_op",

View File

@ -0,0 +1,49 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/kernels/variable_ops.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
REGISTER_RESOURCE_HANDLE_KERNEL(Var);
class CreateVariableOp : public OpKernel {
public:
CreateVariableOp(OpKernelConstruction* c) : OpKernel(c) {
OP_REQUIRES_OK(c, c->GetAttr("dtype", &dtype_));
}
void Compute(OpKernelContext* c) override {
Var* var = new Var(dtype_);
var->Ref();
core::ScopedUnref ur(var);
OP_REQUIRES_OK(c, CreateResource<Var>(c, HandleFromInput(c, 0), var));
// TODO(apassos): this currently does not initialize the tensor, so it's
// pointless, other than checking construction in tests. Fix this.
}
private:
DataType dtype_;
};
REGISTER_KERNEL_BUILDER(Name("CreateVariableOp").Device(DEVICE_CPU),
CreateVariableOp);
} // namespace tensorflow
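For reference, a hedged sketch of how a consumer kernel's Compute might read the variable back through the same handle, assuming the existing resource-manager helpers (LookupResource is not part of this change):

// Sketch: look up the Var created above via the handle in input 0.
Var* var = nullptr;
OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &var));
core::ScopedUnref unref(var);
mutex_lock l(*var->mu());
const Tensor* t = var->tensor();  // currently uninitialized (see TODO above)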

View File

@ -26,6 +26,26 @@ limitations under the License.
namespace tensorflow {
// Resource stored by variables in the resource manager.
class Var : public ResourceBase {
public:
explicit Var(DataType dtype) : tensor_(dtype) {}
mutex* mu() { return &mu_; }
Tensor* tensor() { return &tensor_; }
string DebugString() override {
return strings::StrCat(DataTypeString(tensor_.dtype()), "/",
tensor_.shape().DebugString());
}
private:
mutex mu_;
Tensor tensor_;
~Var() override {}
TF_DISALLOW_COPY_AND_ASSIGN(Var);
};
class VariableOp : public OpKernel {
public:
explicit VariableOp(OpKernelConstruction* context) : OpKernel(context) {
@ -59,25 +79,6 @@ class VariableOp : public OpKernel {
}
private:
class Var : public ResourceBase {
public:
explicit Var(DataType dtype) : tensor_(dtype) {}
mutex* mu() { return &mu_; }
Tensor* tensor() { return &tensor_; }
string DebugString() override {
return strings::StrCat(DataTypeString(tensor_.dtype()), "/",
tensor_.shape().DebugString());
}
private:
mutex mu_;
Tensor tensor_;
~Var() override {}
TF_DISALLOW_COPY_AND_ASSIGN(Var);
};
DataType dtype_;
TensorShape shape_;

View File

@ -1252,7 +1252,12 @@ REGISTER_OP("Identity")
.Input("input: T")
.Output("output: T")
.Attr("T: type")
.SetShapeFn(shape_inference::UnchangedShape)
.SetShapeFn([](shape_inference::InferenceContext* c) {
c->set_output(0, c->input(0));
c->set_output_handle_dtype(0, c->input_handle_dtype(0));
c->set_output_handle_shape(0, c->input_handle_shape(0));
return Status::OK();
})
.Doc(R"Doc(
Return a tensor with the same shape and contents as the input tensor or value.
)Doc");

View File

@ -15,7 +15,9 @@ limitations under the License.
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/shape_inference_testutil.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@ -153,6 +155,21 @@ TEST(ArrayOpsTest, UnchangedShapes_ShapeFn) {
INFER_OK(op, "[1,2,?,4,5];?;?", "in0");
}
TEST(ArrayOpsTest, Identity_ShapeFnHandles) {
const char* op_name = "Identity";
ShapeInferenceTestOp op(op_name);
// Check that handle dtypes are preserved.
const OpRegistrationData* op_reg_data;
TF_ASSERT_OK(OpRegistry::Global()->LookUp(op.name, &op_reg_data));
shape_inference::InferenceContext c(&op.node_def, op_reg_data->op_def,
{TensorShapeProto()}, {}, {}, {},
{DT_BOOL});
TF_ASSERT_OK(c.construction_status());
ASSERT_TRUE(op_reg_data->shape_inference_fn != nullptr);
TF_ASSERT_OK(c.Run(op_reg_data->shape_inference_fn));
EXPECT_TRUE(c.output_handle_dtype(0) == DT_BOOL);
}
TEST(ArrayOpsTest, Diag_ShapeFn) {
ShapeInferenceTestOp op("Diag");
INFER_OK(op, "?", "?");

View File

@ -911,6 +911,19 @@ REGISTER_OP("Select")
.Output("output: T")
.Attr("T: type")
.SetShapeFn([](InferenceContext* c) {
// Merge handle shape and dtype if applicable.
if (c->input_handle_dtype(1) != c->input_handle_dtype(2)) {
// TODO(apassos) resolve this in the manner of b/32476923
return errors::InvalidArgument(
"Trying to merge handles pointing to different dtypes.");
}
c->set_output_handle_dtype(0, c->input_handle_dtype(1));
ShapeHandle output_handle_shape;
TF_RETURN_IF_ERROR(c->Merge(c->input_handle_shape(1),
c->input_handle_shape(2),
&output_handle_shape));
c->set_output_handle_shape(0, output_handle_shape);
// The inputs 'then' and 'else' must have the same shape.
ShapeHandle data = c->input(1);
ShapeHandle other = c->input(2);
@ -961,8 +974,9 @@ REGISTER_OP("Select")
}
c->set_output(0, data);
return Status::OK();
})
})
.Doc(R"doc(
Selects elements from `t` or `e`, depending on `condition`.

View File

@ -216,6 +216,37 @@ TEST(MathOpsTest, Select_ShapeFn) {
INFER_OK(op, "[2,?,?];[?,?,3];[?,2,?]", "[d0_0,d2_1,d1_2]");
INFER_ERROR("Dimension 2 in both shapes must be equal, but are 3 and 5", op,
"[2,?,5];[?,?,3];[?,2,?]");
// Test that handle shapes were merged.
const OpRegistrationData* op_reg_data;
TF_ASSERT_OK(OpRegistry::Global()->LookUp(op.name, &op_reg_data));
TensorShapeProto i0;
i0.add_dim()->set_size(1);
i0.add_dim()->set_size(-1);
TensorShapeProto i1;
i1.add_dim()->set_size(-1);
i1.add_dim()->set_size(2);
ASSERT_TRUE(op_reg_data->shape_inference_fn != nullptr);
shape_inference::InferenceContext c(
&op.node_def, op_reg_data->op_def,
{TensorShapeProto(), TensorShapeProto(), TensorShapeProto()}, {}, {},
{TensorShapeProto(), i0, i1}, {});
TF_ASSERT_OK(c.construction_status());
TF_ASSERT_OK(c.Run(op_reg_data->shape_inference_fn));
EXPECT_TRUE(c.FullyDefined(c.output_handle_shape(0)));
EXPECT_EQ("[1,2]", c.DebugString(c.output_handle_shape(0)));
// Expect an error when the shapes can't be merged.
TensorShapeProto i2;
i2.add_dim()->set_size(2);
i2.add_dim()->set_size(2);
shape_inference::InferenceContext c2(
&op.node_def, op_reg_data->op_def,
{TensorShapeProto(), TensorShapeProto(), TensorShapeProto()}, {}, {},
{TensorShapeProto(), i0, i2}, {});
TF_ASSERT_OK(c2.construction_status());
EXPECT_FALSE(c2.Run(op_reg_data->shape_inference_fn).ok());
}
TEST(MathOpsTest, Range_ShapeFn) {

View File

@ -0,0 +1,80 @@
// Copyright 2016 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ============================================================================
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
REGISTER_OP("VarHandleOp")
.Attr("container: string = ''")
.Attr("shared_name: string = ''")
.Attr("dtype: type")
.Attr("shape: shape")
.Output("resource: resource")
.SetIsStateful()
.SetShapeFn([](shape_inference::InferenceContext* c) {
c->set_output(0, c->Scalar());
DataType t;
TF_RETURN_IF_ERROR(c->GetAttr("dtype", &t));
c->set_output_handle_dtype(0, t);
TensorShapeProto p;
TF_RETURN_IF_ERROR(c->GetAttr("shape", &p));
shape_inference::ShapeHandle s;
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeProto(p, &s));
c->set_output_handle_shape(0, s);
return Status::OK();
})
.Doc(R"(
Creates a handle to a Variable resource.
container: the container this variable is placed in.
shared_name: the name by which this variable is referred to.
dtype: the type of this variable. Must agree with the dtypes
of all ops using this variable.
shape: The (possibly partially specified) shape of this variable.
)");
REGISTER_OP("CreateVariableOp")
.Input("resource: resource")
.Input("value: dtype")
.Attr("dtype: type")
.SetShapeFn([](shape_inference::InferenceContext* c) {
DataType handle_dtype = c->input_handle_dtype(0);
DataType value_dtype;
TF_RETURN_IF_ERROR(c->GetAttr("dtype", &value_dtype));
if (handle_dtype != value_dtype) {
return errors::InvalidArgument(
"Trying to initialize handle for variable with wrong dtype. "
"Expected ",
handle_dtype, " got ", value_dtype);
}
shape_inference::ShapeHandle s = c->input_handle_shape(0);
shape_inference::ShapeHandle value_shape = c->input(1);
shape_inference::ShapeHandle unused;
TF_RETURN_IF_ERROR(c->Merge(s, value_shape, &unused));
return Status::OK();
})
.Doc(R"(
Creates a variable resource.
resource: handle to the resource in which to store the variable.
value: the value with which to initialize the variable.
dtype: the dtype of the value.
)");
} // namespace tensorflow

View File

@ -217,23 +217,6 @@ cc_library(
alwayslink = 1,
)
cc_library(
name = "cpp_shape_inference",
srcs = ["framework/cpp_shape_inference.cc"],
hdrs = ["framework/cpp_shape_inference.h"],
copts = ["-Wno-sign-compare"],
visibility = ["//visibility:public"],
deps = [
":numpy_lib",
":py_func_lib",
"//tensorflow/c:tf_status_helper",
"//tensorflow/core:framework",
"//tensorflow/core:protos_cc",
"//third_party/py/numpy:headers",
"//util/python:python_headers",
],
)
cc_library(
name = "python_op_gen_main",
srcs = [
@ -284,6 +267,7 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
":cpp_shape_inference_proto_py",
":framework_for_generated_wrappers",
":pywrap_tensorflow",
],
@ -660,6 +644,11 @@ tf_gen_op_wrapper_private_py(
require_shape_functions = True,
)
tf_gen_op_wrapper_private_py(
name = "resource_variable_ops_gen",
require_shape_functions = True,
)
tf_gen_op_wrapper_private_py(
name = "script_ops_gen",
require_shape_functions = True,
@ -989,6 +978,16 @@ py_library(
],
)
py_library(
name = "resource_variable_ops",
srcs = ["ops/resource_variable_ops.py"],
srcs_version = "PY2AND3",
deps = [
":framework",
":resource_variable_ops_gen",
],
)
py_library(
name = "nn",
srcs = ["ops/nn.py"],
@ -1409,6 +1408,7 @@ py_library(
":partitioned_variables",
":random_ops",
":random_ops_gen",
":resource_variable_ops",
":resources",
":rnn",
":rnn_cell",
@ -1690,6 +1690,7 @@ tf_proto_library(
["**/*.proto"],
exclude = [
"util/protobuf/compare_test.proto",
"framework/cpp_shape_inference.proto",
],
),
go_api_version = 2,
@ -1701,6 +1702,13 @@ tf_proto_library_py(
srcs = ["util/protobuf/compare_test.proto"],
)
tf_proto_library(
name = "cpp_shape_inference_proto",
srcs = ["framework/cpp_shape_inference.proto"],
cc_api_version = 2,
cc_libs = ["//tensorflow/core:protos_all_cc"],
)
py_test(
name = "protobuf_compare_test",
size = "small",
@ -1767,6 +1775,24 @@ py_library(
],
)
cc_library(
name = "cpp_shape_inference",
srcs = ["framework/cpp_shape_inference.cc"],
hdrs = ["framework/cpp_shape_inference.h"],
copts = ["-Wno-sign-compare"],
visibility = ["//visibility:public"],
deps = [
":cpp_shape_inference_proto_cc",
":numpy_lib",
":py_func_lib",
"//tensorflow/c:tf_status_helper",
"//tensorflow/core:framework",
"//tensorflow/core:protos_cc",
"//third_party/py/numpy:headers",
"//util/python:python_headers",
],
)
cuda_py_tests(
name = "device_lib_test",
size = "small",

View File

@ -20,8 +20,8 @@ from __future__ import print_function
import numpy as np
import six.moves
from tensorflow.core.framework import tensor_shape_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.framework import cpp_shape_inference_pb2
from tensorflow.python.framework import errors
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
@ -567,8 +567,12 @@ def call_cpp_shape_fn(op, input_tensors_needed=None,
the C++ shape function.
Returns:
A TensorShape list of the output shapes of the op, as computed using the
C++ shape inference function registered for the op.
A dictionary with the following keys:
shapes: A TensorShape list of the output shapes of the op, as computed
using the C++ shape inference function registered for the op.
handle_shapes: A TensorShape list of the shapes for handle outputs, if
any.
handle_dtypes: A list of DataType enums for the handle outputs, if any.
Raises:
ValueError: If the C++ shape function returned an error (e.g. because the
@ -576,8 +580,16 @@ def call_cpp_shape_fn(op, input_tensors_needed=None,
according to the shape function).
"""
node_def_str = op.node_def.SerializeToString()
input_shapes = [i.get_shape().as_proto().SerializeToString() for i in
op.inputs]
def tensor_to_inference_result(t):
r = cpp_shape_inference_pb2.CppShapeInferenceResult()
r.shape.CopyFrom(t.get_shape().as_proto())
# pylint: disable=protected-access
r.handle_shape.CopyFrom(t._handle_shape)
r.handle_dtype = t._handle_dtype
# pylint: enable=protected-access
return r.SerializeToString()
input_shapes = [tensor_to_inference_result(i) for i in op.inputs]
input_tensors = [None for i in input_shapes]
if input_tensors_needed:
@ -596,10 +608,13 @@ def call_cpp_shape_fn(op, input_tensors_needed=None,
raise ValueError(err.message)
# Convert TensorShapeProto values in output_shapes.
result = [
tensor_shape.TensorShape(tensor_shape_pb2.TensorShapeProto.FromString(s))
result_protos = [
cpp_shape_inference_pb2.CppShapeInferenceResult().FromString(s)
for s in output_shapes
]
result = [r.shape for r in result_protos]
result_handle_shapes = [r.handle_shape for r in result_protos]
result_handle_dtypes = [r.handle_dtype for r in result_protos]
if debug_python_shape_fn:
try:
@ -616,4 +631,6 @@ def call_cpp_shape_fn(op, input_tensors_needed=None,
str(result), str(python_result), str(op.node_def),
",".join([str(i.get_shape()) for i in op.inputs])))
return result
return {"shapes": result,
"handle_shapes": result_handle_shapes,
"handle_dtypes": result_handle_dtypes}

View File

@ -20,12 +20,32 @@ limitations under the License.
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/python/framework/cpp_shape_inference.pb.h"
#include "tensorflow/python/lib/core/py_func.h"
namespace tensorflow {
namespace swig {
namespace {
void ProtoFromShapeHandle(tensorflow::shape_inference::ShapeHandle s,
tensorflow::shape_inference::InferenceContext* c,
TensorShapeProto* out) {
if (c->RankKnown(s)) {
const int32 rank = c->Rank(s);
for (int i = 0; i < rank; ++i) {
shape_inference::DimensionHandle d = c->Dim(s, i);
auto* out_dim = out->add_dim();
if (c->ValueKnown(d)) {
out_dim->set_size(c->Value(d));
} else {
out_dim->set_size(-1);
}
}
} else {
out->set_unknown_rank(true);
}
}
Status RunCppShapeInferenceImpl(
const string& serialized_node_def,
const std::vector<string>& input_serialized_shapes,
@ -49,12 +69,21 @@ Status RunCppShapeInferenceImpl(
// Convert input shapes.
std::vector<TensorShapeProto> input_shapes;
std::vector<TensorShapeProto> input_handle_shapes;
std::vector<DataType> input_handle_dtypes;
input_shapes.resize(input_serialized_shapes.size());
input_handle_shapes.resize(input_serialized_shapes.size());
input_handle_dtypes.resize(input_serialized_shapes.size());
CppShapeInferenceResult tmp;
for (int i = 0; i < input_serialized_shapes.size(); ++i) {
if (!input_shapes[i].ParseFromString(input_serialized_shapes[i])) {
tmp.Clear();
if (!tmp.ParseFromString(input_serialized_shapes[i])) {
return errors::InvalidArgument(
"Error parsing shape proto during cpp shape inference");
}
input_shapes[i].Swap(tmp.mutable_shape());
input_handle_dtypes[i] = tmp.handle_dtype();
input_handle_shapes[i].Swap(tmp.mutable_handle_shape());
}
// Convert input tensor values;
@ -73,34 +102,23 @@ Status RunCppShapeInferenceImpl(
}
// Run shape inference.
// TODO(cwhipkey): pass a value for input_tensors_as_shapes.
tensorflow::shape_inference::InferenceContext c(
&node, op_reg_data->op_def, input_shapes, input_tensors,
{} /* input_tensors_as_shapes */);
{} /* input_tensors_as_shapes */, input_handle_shapes,
input_handle_dtypes);
TF_RETURN_IF_ERROR(c.construction_status());
TF_RETURN_IF_ERROR(c.Run(op_reg_data->shape_inference_fn));
// Convert output shapes.
output_tensor_shape_protos->resize(c.num_outputs());
TensorShapeProto out;
CppShapeInferenceResult out;
for (int i = 0; i < c.num_outputs(); ++i) {
shape_inference::ShapeHandle s = c.output(i);
out.Clear();
if (c.RankKnown(s)) {
const int32 rank = c.Rank(s);
for (int i = 0; i < rank; ++i) {
shape_inference::DimensionHandle d = c.Dim(s, i);
auto* out_dim = out.add_dim();
if (c.ValueKnown(d)) {
out_dim->set_size(c.Value(d));
} else {
out_dim->set_size(-1);
}
}
} else {
out.set_unknown_rank(true);
}
ProtoFromShapeHandle(c.output(i), &c, out.mutable_shape());
ProtoFromShapeHandle(c.output_handle_shape(i), &c,
out.mutable_handle_shape());
out.set_handle_dtype(c.output_handle_dtype(i));
CHECK(out.AppendToString(&(*output_tensor_shape_protos)[i]));
}

View File

@ -36,7 +36,7 @@ namespace swig {
// inference was successful.
//
// On success, <*output_shapes> is populated with the inferred output shapes (as
// serialized TensorShapeProtos).
// serialized CppShapeInferenceResult protos).
// <*output_shapes> must be empty when this function is called.
//
// This is temporary code to be used during the migration

View File

@ -0,0 +1,13 @@
syntax = "proto3";
package tensorflow;
option cc_enable_arenas = true;
import "tensorflow/core/framework/types.proto";
import "tensorflow/core/framework/tensor_shape.proto";
message CppShapeInferenceResult {
TensorShapeProto shape = 1;
TensorShapeProto handle_shape = 2;
DataType handle_dtype = 3;
}
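A hedged C++ round-trip sketch for the new message, matching how cpp_shape_inference.cc emits one serialized result per output (field values here are illustrative):

// Sketch: build and serialize one per-tensor inference result.
tensorflow::CppShapeInferenceResult r;
r.mutable_shape()->add_dim()->set_size(2);         // output shape [2]
r.mutable_handle_shape()->set_unknown_rank(true);  // no handle shape known
r.set_handle_dtype(tensorflow::DT_INT32);
string serialized;
r.SerializeToString(&serialized);  // parsed back with ParseFromString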

View File

@ -32,6 +32,8 @@ from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import function_pb2
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import node_def_pb2
from tensorflow.core.framework import tensor_shape_pb2
from tensorflow.core.framework import types_pb2
from tensorflow.core.framework import versions_pb2
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import dtypes
@ -297,6 +299,10 @@ class Tensor(object):
# to easily navigate a computation graph.
self._consumers = []
# Attributes used for C++ shape inference. Not inspected, only forwarded.
self._handle_shape = tensor_shape_pb2.TensorShapeProto()
self._handle_dtype = types_pb2.DT_INVALID
@property
def op(self):
"""The `Operation` that produces this tensor as an output."""
@ -1791,10 +1797,22 @@ def set_shapes_for_outputs(op):
if shapes is None:
raise RuntimeError(
"Shape function for op %s did not return any shapes" % op)
elif isinstance(shapes, dict):
# Returned by call_cpp_shape_fn
shapes_dict = shapes
shapes = shapes_dict["shapes"]
handle_shapes = shapes_dict["handle_shapes"]
handle_dtypes = shapes_dict["handle_dtypes"]
for output, handle_shape, handle_dtype in zip(op.outputs, handle_shapes, handle_dtypes):
# pylint: disable=protected-access
output._handle_shape = handle_shape
output._handle_dtype = handle_dtype
# pylint: enable=protected-access
if len(op.outputs) != len(shapes):
raise RuntimeError(
"Shape function for op %s returned %d shapes but expected %d" %
(op, len(shapes), len(op.outputs)))
"Shape function for op %s returned %d shapes but expected %d %s %s" %
(op, len(shapes), len(op.outputs), shape_func.__name__, str(shapes)))
for output, s in zip(op.outputs, shapes):
output.set_shape(s)

View File

@ -255,6 +255,16 @@ tf_py_test(
additional_deps = ["//tensorflow:tensorflow_py"],
)
tf_py_test(
name = "resource_variable_ops_test",
size = "small",
srcs = ["resource_variable_ops_test.py"],
additional_deps = [
"//tensorflow:tensorflow_py",
"//tensorflow/python:resource_variable_ops",
],
)
tf_py_test(
name = "save_restore_ops_test",
size = "small",

View File

@ -0,0 +1,51 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.ops.resource_variable_ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import test
class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
def testHandleDtypeShapeMatch(self):
with self.test_session():
handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
with self.assertRaises(ValueError):
resource_variable_ops.create_variable_op(
handle, constant_op.constant(0.0, dtype=dtypes.float32)).run()
with self.assertRaises(ValueError):
resource_variable_ops.create_variable_op(
handle, constant_op.constant([0], dtype=dtypes.int32)).run()
resource_variable_ops.create_variable_op(
handle, constant_op.constant(0, dtype=dtypes.int32)).run()
def testDtypeSurvivesIdentity(self):
with self.test_session():
handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
id_handle = array_ops.identity(handle)
resource_variable_ops.create_variable_op(
id_handle, constant_op.constant(0, dtype=dtypes.int32)).run()
if __name__ == "__main__":
test.main()

View File

@ -0,0 +1,30 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Ops to use variables as resources."""
# pylint: disable=g-bad-name
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import common_shapes
from tensorflow.python.framework import ops
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_resource_variable_ops import *
# pylint: enable=wildcard-import
ops.RegisterShape("VarHandleOp")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("CreateVariableOp")(common_shapes.call_cpp_shape_fn)