Automated rollback of change 137740850
Change: 137747341
This commit is contained in:
parent
1dd44f3ecc
commit
41734d78d3
tensorflow
contrib/cmake
core
BUILD
common_runtime
framework
common_shape_fns_test.ccnode_def_util.hresource_mgr.hshape_inference.ccshape_inference.hshape_inference_test.ccshape_inference_testutil.cctensor.h
kernels
ops
python
@ -94,6 +94,38 @@ function(RELATIVE_PROTOBUF_GENERATE_PYTHON ROOT_DIR SRCS)
|
||||
set(${SRCS} ${${SRCS}} PARENT_SCOPE)
|
||||
endfunction()
|
||||
|
||||
function(RELATIVE_PROTOBUF_GENERATE_CPP SRCS HDRS ROOT_DIR)
|
||||
if(NOT ARGN)
|
||||
message(SEND_ERROR "Error: RELATIVE_PROTOBUF_GENERATE_CPP() called without any proto files")
|
||||
return()
|
||||
endif()
|
||||
|
||||
set(${SRCS})
|
||||
set(${HDRS})
|
||||
foreach(FIL ${ARGN})
|
||||
set(ABS_FIL ${ROOT_DIR}/${FIL})
|
||||
get_filename_component(FIL_WE ${FIL} NAME_WE)
|
||||
get_filename_component(FIL_DIR ${ABS_FIL} PATH)
|
||||
file(RELATIVE_PATH REL_DIR ${ROOT_DIR} ${FIL_DIR})
|
||||
|
||||
list(APPEND ${SRCS} "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.pb.cc")
|
||||
list(APPEND ${HDRS} "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.pb.h")
|
||||
|
||||
add_custom_command(
|
||||
OUTPUT "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.pb.cc"
|
||||
"${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.pb.h"
|
||||
COMMAND ${PROTOBUF_PROTOC_EXECUTABLE}
|
||||
ARGS --cpp_out ${CMAKE_CURRENT_BINARY_DIR} -I ${ROOT_DIR} ${ABS_FIL} -I ${PROTOBUF_INCLUDE_DIRS}
|
||||
DEPENDS ${ABS_FIL} protobuf
|
||||
COMMENT "Running C++ protocol buffer compiler on ${FIL}"
|
||||
VERBATIM )
|
||||
endforeach()
|
||||
|
||||
set_source_files_properties(${${SRCS}} ${${HDRS}} PROPERTIES GENERATED TRUE)
|
||||
set(${SRCS} ${${SRCS}} PARENT_SCOPE)
|
||||
set(${HDRS} ${${HDRS}} PARENT_SCOPE)
|
||||
endfunction()
|
||||
|
||||
file(GLOB_RECURSE tf_protos_python_srcs RELATIVE ${tensorflow_source_dir}
|
||||
"${tensorflow_source_dir}/tensorflow/core/*.proto"
|
||||
"${tensorflow_source_dir}/tensorflow/python/*.proto"
|
||||
@ -102,6 +134,12 @@ RELATIVE_PROTOBUF_GENERATE_PYTHON(
|
||||
${tensorflow_source_dir} PYTHON_PROTO_GENFILES ${tf_protos_python_srcs}
|
||||
)
|
||||
|
||||
RELATIVE_PROTOBUF_GENERATE_CPP(PROTO_SRCS PROTO_HDRS
|
||||
${tensorflow_source_dir} ${tf_protos_python_srcs}
|
||||
)
|
||||
|
||||
add_library(tf_python_protos_cc ${PROTO_SRCS} ${PROTO_HDRS})
|
||||
|
||||
# tf_python_touchup_modules adds empty __init__.py files to all
|
||||
# directories containing Python code, so that Python will recognize
|
||||
# them as modules.
|
||||
@ -201,6 +239,7 @@ function(GENERATE_PYTHON_OP_LIB tf_python_op_lib_name)
|
||||
)
|
||||
target_link_libraries(${tf_python_op_lib_name}_gen_python PRIVATE
|
||||
tf_protos_cc
|
||||
tf_python_protos_cc
|
||||
${tensorflow_EXTERNAL_LIBRARIES}
|
||||
)
|
||||
|
||||
@ -312,6 +351,7 @@ target_link_libraries(pywrap_tensorflow
|
||||
${tf_core_gpu_kernels_lib}
|
||||
${tensorflow_EXTERNAL_LIBRARIES}
|
||||
tf_protos_cc
|
||||
tf_python_protos_cc
|
||||
${PYTHON_LIBRARIES}
|
||||
)
|
||||
|
||||
|
@ -379,6 +379,7 @@ tf_gen_op_libs(
|
||||
"no_op",
|
||||
"parsing_ops",
|
||||
"random_ops",
|
||||
"resource_variable_ops",
|
||||
"sdca_ops",
|
||||
"script_ops",
|
||||
"sendrecv_ops",
|
||||
@ -542,6 +543,7 @@ cc_library(
|
||||
"//tensorflow/core/kernels:parsing",
|
||||
"//tensorflow/core/kernels:random_ops",
|
||||
"//tensorflow/core/kernels:required",
|
||||
"//tensorflow/core/kernels:resource_variable_ops",
|
||||
"//tensorflow/core/kernels:sdca_ops",
|
||||
"//tensorflow/core/kernels:sparse",
|
||||
"//tensorflow/core/kernels:state",
|
||||
|
@ -42,6 +42,8 @@ Status ShapeRefiner::AddNode(const Node* node) {
|
||||
// indexed by 'node's input.
|
||||
std::vector<Node*> input_nodes(node->num_inputs());
|
||||
std::vector<ShapeHandle> input_shapes(node->num_inputs());
|
||||
std::vector<DataType> input_handle_dtypes(node->num_inputs());
|
||||
std::vector<ShapeHandle> input_handle_shapes(node->num_inputs());
|
||||
for (const Edge* e : node->in_edges()) {
|
||||
if (e->IsControlEdge()) continue;
|
||||
|
||||
@ -57,6 +59,15 @@ Status ShapeRefiner::AddNode(const Node* node) {
|
||||
DCHECK_GE(e->dst_input(), 0);
|
||||
input_nodes[e->dst_input()] = input;
|
||||
input_shapes[e->dst_input()] = c->output(e->src_output());
|
||||
|
||||
// Only propagate handle xshape and dtype of edges which are carrying
|
||||
// resource handles.
|
||||
if (e->src()->output_type(e->src_output()) == DT_RESOURCE) {
|
||||
input_handle_dtypes[e->dst_input()] =
|
||||
c->output_handle_dtype(e->src_output());
|
||||
input_handle_shapes[e->dst_input()] =
|
||||
c->output_handle_shape(e->src_output());
|
||||
}
|
||||
}
|
||||
|
||||
// Get the shape function for this node
|
||||
@ -76,9 +87,9 @@ Status ShapeRefiner::AddNode(const Node* node) {
|
||||
std::vector<ShapeHandle> input_tensors_as_shapes;
|
||||
|
||||
// Create the inference context for this node with the existing input shapes.
|
||||
std::unique_ptr<InferenceContext> c(
|
||||
new InferenceContext(&node->def(), node->op_def(), input_shapes,
|
||||
input_tensors, input_tensors_as_shapes));
|
||||
std::unique_ptr<InferenceContext> c(new InferenceContext(
|
||||
&node->def(), node->op_def(), input_shapes, input_tensors,
|
||||
input_tensors_as_shapes, input_handle_shapes, input_handle_dtypes));
|
||||
if (!c->construction_status().ok()) {
|
||||
return c->construction_status();
|
||||
}
|
||||
|
@ -56,7 +56,7 @@ TEST(CommonShapeFnsTest, NoOutputShapeTest) {
|
||||
.Input({{"data", 0, DT_FLOAT}})
|
||||
.Finalize(&def));
|
||||
|
||||
InferenceContext c(&def, op_def, {S({}), S({10})}, {}, {});
|
||||
InferenceContext c(&def, op_def, {S({}), S({10})}, {}, {}, {}, {});
|
||||
TF_EXPECT_OK(NoOutputs(&c));
|
||||
EXPECT_EQ(0, c.num_outputs());
|
||||
}
|
||||
@ -74,14 +74,14 @@ TEST(CommonShapeFnsTest, ScalarShapeTest) {
|
||||
NodeDefBuilder("test", "L2Loss").Input("t", 0, DT_FLOAT).Finalize(&def));
|
||||
|
||||
{
|
||||
InferenceContext c(&def, op_def, {S({})}, {}, {});
|
||||
InferenceContext c(&def, op_def, {S({})}, {}, {}, {}, {});
|
||||
TF_EXPECT_OK(ScalarShape(&c));
|
||||
ShapeHandle output = c.output(0);
|
||||
EXPECT_EQ(0, c.Rank(output));
|
||||
}
|
||||
|
||||
{
|
||||
InferenceContext c(&def, op_def, {S({1, 23, 4, 4, 2})}, {}, {});
|
||||
InferenceContext c(&def, op_def, {S({1, 23, 4, 4, 2})}, {}, {}, {}, {});
|
||||
TF_EXPECT_OK(ScalarShape(&c));
|
||||
ShapeHandle output = c.output(0);
|
||||
EXPECT_EQ(0, c.Rank(output));
|
||||
@ -108,7 +108,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
|
||||
.Finalize(&def));
|
||||
|
||||
{
|
||||
InferenceContext c(&def, op_def, {S({2, 3}), S({3, 4})}, {}, {});
|
||||
InferenceContext c(&def, op_def, {S({2, 3}), S({3, 4})}, {}, {}, {}, {});
|
||||
TF_EXPECT_OK(MatMulShape(&c));
|
||||
ShapeHandle output = c.output(0);
|
||||
EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
|
||||
@ -117,7 +117,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
|
||||
|
||||
{
|
||||
// Unknown inner dimension for one
|
||||
InferenceContext c(&def, op_def, {S({2, -1}), S({3, 4})}, {}, {});
|
||||
InferenceContext c(&def, op_def, {S({2, -1}), S({3, 4})}, {}, {}, {}, {});
|
||||
TF_EXPECT_OK(MatMulShape(&c));
|
||||
ShapeHandle output = c.output(0);
|
||||
EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
|
||||
@ -126,7 +126,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
|
||||
|
||||
{
|
||||
// Invalid rank.
|
||||
InferenceContext c(&def, op_def, {S({2}), S({3, 4})}, {}, {});
|
||||
InferenceContext c(&def, op_def, {S({2}), S({3, 4})}, {}, {}, {}, {});
|
||||
auto s = MatMulShape(&c);
|
||||
EXPECT_FALSE(s.ok());
|
||||
EXPECT_TRUE(
|
||||
@ -136,7 +136,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
|
||||
|
||||
{
|
||||
// Unknown outer dimension
|
||||
InferenceContext c(&def, op_def, {S({2, 3}), S({3, -1})}, {}, {});
|
||||
InferenceContext c(&def, op_def, {S({2, 3}), S({3, -1})}, {}, {}, {}, {});
|
||||
TF_EXPECT_OK(MatMulShape(&c));
|
||||
ShapeHandle output = c.output(0);
|
||||
EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
|
||||
@ -145,7 +145,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
|
||||
|
||||
{
|
||||
// Inner shapes not compatible
|
||||
InferenceContext c(&def, op_def, {S({2, 5}), S({3, 4})}, {}, {});
|
||||
InferenceContext c(&def, op_def, {S({2, 5}), S({3, 4})}, {}, {}, {}, {});
|
||||
auto s = MatMulShape(&c);
|
||||
EXPECT_FALSE(s.ok());
|
||||
EXPECT_TRUE(
|
||||
@ -156,7 +156,8 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
|
||||
|
||||
{
|
||||
// Inner shapes not compatible
|
||||
InferenceContext c(&def, op_def, {S({2, 5, 3}), S({3, 5, 4})}, {}, {});
|
||||
InferenceContext c(&def, op_def, {S({2, 5, 3}), S({3, 5, 4})}, {}, {}, {},
|
||||
{});
|
||||
auto s = MatMulShape(&c);
|
||||
EXPECT_FALSE(s.ok());
|
||||
EXPECT_TRUE(
|
||||
@ -174,7 +175,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
|
||||
.Attr("type", DT_FLOAT)
|
||||
.Finalize(&def));
|
||||
|
||||
InferenceContext c(&def, op_def, {S({3, 2}), S({3, 4})}, {}, {});
|
||||
InferenceContext c(&def, op_def, {S({3, 2}), S({3, 4})}, {}, {}, {}, {});
|
||||
auto s = MatMulShape(&c);
|
||||
ShapeHandle output = c.output(0);
|
||||
EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
|
||||
@ -191,7 +192,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
|
||||
.Attr("type", DT_FLOAT)
|
||||
.Finalize(&def));
|
||||
|
||||
InferenceContext c(&def, op_def, {S({2, 3}), S({4, 3})}, {}, {});
|
||||
InferenceContext c(&def, op_def, {S({2, 3}), S({4, 3})}, {}, {}, {}, {});
|
||||
auto s = MatMulShape(&c);
|
||||
ShapeHandle output = c.output(0);
|
||||
EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
|
||||
@ -215,7 +216,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
|
||||
.Finalize(&def));
|
||||
|
||||
{
|
||||
InferenceContext c(&def, op_def, {S({2, 10}), S({10})}, {}, {});
|
||||
InferenceContext c(&def, op_def, {S({2, 10}), S({10})}, {}, {}, {}, {});
|
||||
TF_EXPECT_OK(BiasAddShape(&c));
|
||||
ShapeHandle output = c.output(0);
|
||||
EXPECT_EQ(2, c.Value(c.Dim(output, 0)));
|
||||
@ -224,7 +225,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
|
||||
|
||||
{
|
||||
// Unknown ranks.
|
||||
InferenceContext c(&def, op_def, {Unknown(), Unknown()}, {}, {});
|
||||
InferenceContext c(&def, op_def, {Unknown(), Unknown()}, {}, {}, {}, {});
|
||||
TF_EXPECT_OK(BiasAddShape(&c));
|
||||
ShapeHandle output = c.output(0);
|
||||
EXPECT_FALSE(c.RankKnown(output));
|
||||
@ -232,7 +233,8 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
|
||||
|
||||
{
|
||||
// Rank > 2
|
||||
InferenceContext c(&def, op_def, {S({4, 3, 4, 2, 15}), S({15})}, {}, {});
|
||||
InferenceContext c(&def, op_def, {S({4, 3, 4, 2, 15}), S({15})}, {}, {}, {},
|
||||
{});
|
||||
TF_EXPECT_OK(BiasAddShape(&c));
|
||||
ShapeHandle output = c.output(0);
|
||||
EXPECT_EQ("[4,3,4,2,15]", c.DebugString(output));
|
||||
@ -245,7 +247,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
|
||||
.Input("b", 0, DT_FLOAT)
|
||||
.Attr("data_format", "NCHW")
|
||||
.Finalize(&def));
|
||||
InferenceContext c(&def, op_def, {S({2, 3, 4, 5}), S({3})}, {}, {});
|
||||
InferenceContext c(&def, op_def, {S({2, 3, 4, 5}), S({3})}, {}, {}, {}, {});
|
||||
TF_EXPECT_OK(BiasAddShape(&c));
|
||||
ShapeHandle output = c.output(0);
|
||||
EXPECT_EQ("[2,3,4,5]", c.DebugString(output));
|
||||
@ -258,8 +260,8 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
|
||||
.Input("b", 0, DT_FLOAT)
|
||||
.Attr("data_format", "NCHW")
|
||||
.Finalize(&def));
|
||||
InferenceContext c(&def, op_def, {S({8, 6, 4, 2, 3, 4, 5}), S({3})}, {},
|
||||
{});
|
||||
InferenceContext c(&def, op_def, {S({8, 6, 4, 2, 3, 4, 5}), S({3})}, {}, {},
|
||||
{}, {});
|
||||
TF_EXPECT_OK(BiasAddShape(&c));
|
||||
ShapeHandle output = c.output(0);
|
||||
EXPECT_EQ("[8,6,4,2,3,4,5]", c.DebugString(output));
|
||||
@ -272,7 +274,8 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
|
||||
.Input("b", 0, DT_FLOAT)
|
||||
.Attr("data_format", "NCHW")
|
||||
.Finalize(&def));
|
||||
InferenceContext c(&def, op_def, {S({10, 11, 12}), S({10})}, {}, {});
|
||||
InferenceContext c(&def, op_def, {S({10, 11, 12}), S({10})}, {}, {}, {},
|
||||
{});
|
||||
TF_EXPECT_OK(BiasAddShape(&c));
|
||||
ShapeHandle output = c.output(0);
|
||||
EXPECT_EQ("[10,11,12]", c.DebugString(output));
|
||||
@ -280,7 +283,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
|
||||
|
||||
{
|
||||
// Input rank not high enough
|
||||
InferenceContext c(&def, op_def, {S({3}), S({3})}, {}, {});
|
||||
InferenceContext c(&def, op_def, {S({3}), S({3})}, {}, {}, {}, {});
|
||||
EXPECT_FALSE(BiasAddShape(&c).ok());
|
||||
}
|
||||
|
||||
@ -292,7 +295,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
|
||||
.Attr("data_format", "NCHW")
|
||||
.Finalize(&def));
|
||||
// NCHW format
|
||||
InferenceContext c(&def, op_def, {S({2, 3}), S({3})}, {}, {});
|
||||
InferenceContext c(&def, op_def, {S({2, 3}), S({3})}, {}, {}, {}, {});
|
||||
EXPECT_FALSE(BiasAddShape(&c).ok());
|
||||
}
|
||||
}
|
||||
@ -311,7 +314,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
|
||||
.Finalize(&def));
|
||||
|
||||
{
|
||||
InferenceContext c(&def, op_def, {S({2, 10})}, {}, {});
|
||||
InferenceContext c(&def, op_def, {S({2, 10})}, {}, {}, {}, {});
|
||||
TF_EXPECT_OK(BiasAddGradShape(&c));
|
||||
ShapeHandle output = c.output(0);
|
||||
EXPECT_EQ(10, c.Value(c.Dim(output, 0)));
|
||||
@ -319,7 +322,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
|
||||
|
||||
{
|
||||
// Rank > 2
|
||||
InferenceContext c(&def, op_def, {S({5, 7, 2, 10})}, {}, {});
|
||||
InferenceContext c(&def, op_def, {S({5, 7, 2, 10})}, {}, {}, {}, {});
|
||||
TF_EXPECT_OK(BiasAddGradShape(&c));
|
||||
ShapeHandle output = c.output(0);
|
||||
EXPECT_EQ(10, c.Value(c.Dim(output, 0)));
|
||||
@ -331,7 +334,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
|
||||
.Input("a", 0, DT_FLOAT)
|
||||
.Attr("data_format", "NCHW")
|
||||
.Finalize(&def));
|
||||
InferenceContext c(&def, op_def, {S({2, 3, 4, 5})}, {}, {});
|
||||
InferenceContext c(&def, op_def, {S({2, 3, 4, 5})}, {}, {}, {}, {});
|
||||
TF_EXPECT_OK(BiasAddGradShape(&c));
|
||||
ShapeHandle output = c.output(0);
|
||||
EXPECT_EQ(3, c.Value(c.Dim(output, 0)));
|
||||
@ -343,7 +346,8 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
|
||||
.Input("a", 0, DT_FLOAT)
|
||||
.Attr("data_format", "NCHW")
|
||||
.Finalize(&def));
|
||||
InferenceContext c(&def, op_def, {S({8, 6, 4, 2, 3, 4, 5})}, {}, {});
|
||||
InferenceContext c(&def, op_def, {S({8, 6, 4, 2, 3, 4, 5})}, {}, {}, {},
|
||||
{});
|
||||
TF_EXPECT_OK(BiasAddGradShape(&c));
|
||||
ShapeHandle output = c.output(0);
|
||||
EXPECT_EQ(3, c.Value(c.Dim(output, 0)));
|
||||
@ -355,7 +359,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
|
||||
.Input("a", 0, DT_FLOAT)
|
||||
.Attr("data_format", "NCHW")
|
||||
.Finalize(&def));
|
||||
InferenceContext c(&def, op_def, {S({10, 11, 12})}, {}, {});
|
||||
InferenceContext c(&def, op_def, {S({10, 11, 12})}, {}, {}, {}, {});
|
||||
TF_EXPECT_OK(BiasAddGradShape(&c));
|
||||
ShapeHandle output = c.output(0);
|
||||
EXPECT_EQ(10, c.Value(c.Dim(output, 0)));
|
||||
@ -363,7 +367,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
|
||||
|
||||
{
|
||||
// Input rank not high enough
|
||||
InferenceContext c(&def, op_def, {S({3})}, {}, {});
|
||||
InferenceContext c(&def, op_def, {S({3})}, {}, {}, {}, {});
|
||||
EXPECT_FALSE(BiasAddGradShape(&c).ok());
|
||||
}
|
||||
|
||||
@ -374,7 +378,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
|
||||
.Attr("data_format", "NCHW")
|
||||
.Finalize(&def));
|
||||
// NCHW format
|
||||
InferenceContext c(&def, op_def, {S({2, 3})}, {}, {});
|
||||
InferenceContext c(&def, op_def, {S({2, 3})}, {}, {}, {}, {});
|
||||
EXPECT_FALSE(BiasAddGradShape(&c).ok());
|
||||
}
|
||||
}
|
||||
|
@ -66,6 +66,14 @@ void AddNodeAttr(StringPiece name, std::initializer_list<T> value,
|
||||
AttrValueMap::value_type(name.ToString(), attr_value));
|
||||
}
|
||||
|
||||
// Adds an attr to an attr value map.
|
||||
template <class T>
|
||||
void AddAttr(StringPiece name, T&& value, AttrValueMap* map) {
|
||||
AttrValue attr_value;
|
||||
SetAttrValue(value, &attr_value);
|
||||
map->insert(AttrValueMap::value_type(name.ToString(), attr_value));
|
||||
}
|
||||
|
||||
class AttrSlice {
|
||||
public:
|
||||
AttrSlice(const NodeDef& node_def); // NOLINT(runtime/explicit)
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include <typeinfo>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
|
@ -31,7 +31,9 @@ InferenceContext::InferenceContext(
|
||||
const NodeDef* node_def, const OpDef& op_def,
|
||||
const std::vector<TensorShapeProto>& input_shapes,
|
||||
const std::vector<const Tensor*>& input_tensors,
|
||||
const std::vector<ShapeHandle>& input_tensors_as_shapes)
|
||||
const std::vector<ShapeHandle>& input_tensors_as_shapes,
|
||||
const std::vector<TensorShapeProto>& input_handle_shapes,
|
||||
const std::vector<DataType>& input_handle_dtypes)
|
||||
: node_def_(*CHECK_NOTNULL(node_def)) {
|
||||
PreInputInit(op_def, input_tensors, input_tensors_as_shapes);
|
||||
if (!construction_status_.ok()) return;
|
||||
@ -43,19 +45,30 @@ InferenceContext::InferenceContext(
|
||||
}
|
||||
inputs_.push_back(shape);
|
||||
}
|
||||
PostInputInit();
|
||||
std::vector<ShapeHandle> handle_shapes;
|
||||
for (const auto& p : input_handle_shapes) {
|
||||
ShapeHandle shape;
|
||||
construction_status_.Update(MakeShapeFromShapeProto(p, &shape));
|
||||
if (!construction_status_.ok()) {
|
||||
return;
|
||||
}
|
||||
handle_shapes.push_back(shape);
|
||||
}
|
||||
PostInputInit(handle_shapes, input_handle_dtypes);
|
||||
}
|
||||
|
||||
InferenceContext::InferenceContext(
|
||||
const NodeDef* node_def, const OpDef& op_def,
|
||||
const std::vector<ShapeHandle>& input_shapes,
|
||||
const std::vector<const Tensor*>& input_tensors,
|
||||
const std::vector<ShapeHandle>& input_tensors_as_shapes)
|
||||
const std::vector<ShapeHandle>& input_tensors_as_shapes,
|
||||
const std::vector<ShapeHandle>& input_handle_shapes,
|
||||
const std::vector<DataType>& input_handle_dtypes)
|
||||
: node_def_(*CHECK_NOTNULL(node_def)) {
|
||||
PreInputInit(op_def, input_tensors, input_tensors_as_shapes);
|
||||
if (!construction_status_.ok()) return;
|
||||
inputs_ = input_shapes;
|
||||
PostInputInit();
|
||||
PostInputInit(input_handle_shapes, input_handle_dtypes);
|
||||
}
|
||||
|
||||
InferenceContext::~InferenceContext() {
|
||||
@ -124,15 +137,44 @@ void InferenceContext::PreInputInit(
|
||||
for (int i = 0; i < num_outputs; ++i) {
|
||||
outputs_.push_back(nullptr);
|
||||
}
|
||||
output_handle_shape_.reserve(num_outputs);
|
||||
for (int i = 0; i < num_outputs; ++i) {
|
||||
output_handle_shape_.push_back(UnknownShape());
|
||||
}
|
||||
output_handle_dtype_ = std::vector<DataType>(num_outputs, DT_INVALID);
|
||||
}
|
||||
|
||||
void InferenceContext::PostInputInit() {
|
||||
void InferenceContext::PostInputInit(
|
||||
const std::vector<ShapeHandle>& input_handle_shapes,
|
||||
const std::vector<DataType>& input_handle_dtypes) {
|
||||
int num_inputs_from_node_def = 0;
|
||||
for (const auto& e : input_name_map_) {
|
||||
num_inputs_from_node_def =
|
||||
std::max(num_inputs_from_node_def, e.second.second);
|
||||
}
|
||||
|
||||
// Allow passing empty shapes/dtypes to avoid changing every single test.
|
||||
if (input_handle_shapes.empty()) {
|
||||
input_handle_shape_.resize(inputs_.size());
|
||||
} else {
|
||||
input_handle_shape_ = input_handle_shapes;
|
||||
if (input_handle_shape_.size() != inputs_.size()) {
|
||||
construction_status_ = errors::InvalidArgument(
|
||||
"Wrong number of handle shapes passed; expected ", inputs_.size(),
|
||||
" got ", input_handle_shape_.size());
|
||||
}
|
||||
}
|
||||
if (input_handle_dtypes.empty()) {
|
||||
input_handle_dtype_ = std::vector<DataType>(inputs_.size(), DT_INVALID);
|
||||
} else {
|
||||
input_handle_dtype_ = input_handle_dtypes;
|
||||
if (input_handle_dtype_.size() != inputs_.size()) {
|
||||
construction_status_ = errors::InvalidArgument(
|
||||
"Wrong number of handle dtypes passed; expected ", inputs_.size(),
|
||||
" got ", input_handle_dtype_.size());
|
||||
}
|
||||
}
|
||||
|
||||
if (inputs_.size() != num_inputs_from_node_def) {
|
||||
construction_status_ = errors::InvalidArgument(
|
||||
"Wrong number of inputs passed: ", inputs_.size(), " while ",
|
||||
@ -737,6 +779,13 @@ Status InferenceContext::AttachContext(const Status& status) {
|
||||
strings::StrCat(status.error_message(), error_context));
|
||||
}
|
||||
|
||||
ShapeHandle InferenceContext::input_handle_shape(int idx) {
|
||||
if (!input_handle_shape_[idx].IsSet()) {
|
||||
input_handle_shape_[idx] = UnknownShape();
|
||||
}
|
||||
return input_handle_shape_[idx];
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// ShapeManager
|
||||
// -----------------------------------------------------------------------------
|
||||
|
@ -147,7 +147,9 @@ class InferenceContext {
|
||||
InferenceContext(const NodeDef* node_def, const OpDef& op_def,
|
||||
const std::vector<ShapeHandle>& input_shapes,
|
||||
const std::vector<const Tensor*>& input_tensors,
|
||||
const std::vector<ShapeHandle>& input_tensors_as_shapes);
|
||||
const std::vector<ShapeHandle>& input_tensors_as_shapes,
|
||||
const std::vector<ShapeHandle>& input_handle_shapes,
|
||||
const std::vector<DataType>& input_handle_dtypes);
|
||||
|
||||
// <input_tensors> is NULL-padded to be the same size as <input_shapes>.
|
||||
//
|
||||
@ -162,7 +164,9 @@ class InferenceContext {
|
||||
InferenceContext(const NodeDef* node_def, const OpDef& op_def,
|
||||
const std::vector<TensorShapeProto>& input_shapes,
|
||||
const std::vector<const Tensor*>& input_tensors,
|
||||
const std::vector<ShapeHandle>& input_tensors_as_shapes);
|
||||
const std::vector<ShapeHandle>& input_tensors_as_shapes,
|
||||
const std::vector<TensorShapeProto>& input_handle_shapes,
|
||||
const std::vector<DataType>& input_handle_dtypes);
|
||||
|
||||
~InferenceContext();
|
||||
|
||||
@ -231,12 +235,12 @@ class InferenceContext {
|
||||
}
|
||||
return s->dims_[idx];
|
||||
}
|
||||
int32 Rank(ShapeHandle s) { return s->rank_; }
|
||||
bool RankKnown(ShapeHandle s) { return Rank(s) != kUnknownRank; }
|
||||
inline int64 Value(DimensionOrConstant d) {
|
||||
int32 Rank(ShapeHandle s) const { return s->rank_; }
|
||||
bool RankKnown(ShapeHandle s) const { return Rank(s) != kUnknownRank; }
|
||||
inline int64 Value(DimensionOrConstant d) const {
|
||||
return d.dim.IsSet() ? d.dim->value_ : d.val;
|
||||
}
|
||||
inline bool ValueKnown(DimensionOrConstant d) {
|
||||
inline bool ValueKnown(DimensionOrConstant d) const {
|
||||
return Value(d) != kUnknownDim;
|
||||
}
|
||||
|
||||
@ -391,6 +395,30 @@ class InferenceContext {
|
||||
|
||||
Status construction_status() const { return construction_status_; }
|
||||
|
||||
// Methods to propagate shape and dtype on edges of handles. Handles are the
|
||||
// dtype DT_RESOURCE which can be used to access state stored in a
|
||||
// ResourceManager. When ops (such as variables) consume these handles to
|
||||
// produce tensors they might need to know side-information about the shapes
|
||||
// and dtypes of tensors which can be accessed via the handle. These methods
|
||||
// propagate that information. Output handle dtypes and shapes are ignored if
|
||||
// the output tensor is not of type DT_RESOURCE.
|
||||
ShapeHandle input_handle_shape(int idx);
|
||||
DataType input_handle_dtype(int idx) const {
|
||||
return input_handle_dtype_[idx];
|
||||
}
|
||||
void set_output_handle_shape(int idx, ShapeHandle shape) {
|
||||
output_handle_shape_[idx] = shape;
|
||||
}
|
||||
void set_output_handle_dtype(int idx, DataType dtype) {
|
||||
output_handle_dtype_[idx] = dtype;
|
||||
}
|
||||
ShapeHandle output_handle_shape(int idx) const {
|
||||
return output_handle_shape_[idx];
|
||||
}
|
||||
DataType output_handle_dtype(int idx) const {
|
||||
return output_handle_dtype_[idx];
|
||||
}
|
||||
|
||||
// Validates the 3 component tensors of a sparse tensor have the proper
|
||||
// shapes. This mimics SparseTensor.__init__ in python/framework/ops.py.
|
||||
Status ValidateSparseTensor(ShapeHandle indices_shape,
|
||||
@ -481,7 +509,8 @@ class InferenceContext {
|
||||
void PreInputInit(const OpDef& op_def,
|
||||
const std::vector<const Tensor*>& input_tensors,
|
||||
const std::vector<ShapeHandle>& input_tensors_as_shapes);
|
||||
void PostInputInit();
|
||||
void PostInputInit(const std::vector<ShapeHandle>& input_handle_shapes,
|
||||
const std::vector<DataType>& input_handle_dtypes);
|
||||
|
||||
DimensionHandle GetDimension(const DimensionOrConstant& d);
|
||||
|
||||
@ -510,6 +539,11 @@ class InferenceContext {
|
||||
std::vector<ShapeHandle> input_tensors_as_shapes_;
|
||||
std::vector<bool> requested_input_tensor_as_partial_shape_;
|
||||
|
||||
std::vector<ShapeHandle> input_handle_shape_;
|
||||
std::vector<DataType> input_handle_dtype_;
|
||||
std::vector<ShapeHandle> output_handle_shape_;
|
||||
std::vector<DataType> output_handle_dtype_;
|
||||
|
||||
const NodeDef& node_def_;
|
||||
NameRangeMap input_name_map_;
|
||||
NameRangeMap output_name_map_;
|
||||
|
@ -71,7 +71,8 @@ TEST_F(ShapeInferenceTest, InputOutputByName) {
|
||||
.Attr("N", 3)
|
||||
.Input(FakeInput(DT_FLOAT))
|
||||
.Finalize(&def);
|
||||
InferenceContext c(&def, op_def, {S({1, 5}), S({2, 5}), S({1, 3})}, {}, {});
|
||||
InferenceContext c(&def, op_def, {S({1, 5}), S({2, 5}), S({1, 3})}, {}, {},
|
||||
{}, {});
|
||||
|
||||
EXPECT_EQ("5", c.DebugString(c.NumElements(c.input(0))));
|
||||
EXPECT_EQ("10", c.DebugString(c.NumElements(c.input(1))));
|
||||
@ -107,7 +108,7 @@ static OpDef MakeOpDef(int num_inputs, int num_outputs) {
|
||||
|
||||
TEST_F(ShapeInferenceTest, DimensionOrConstant) {
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(1, 1), {Unknown()}, {}, {});
|
||||
InferenceContext c(&def, MakeOpDef(1, 1), {Unknown()}, {}, {}, {}, {});
|
||||
EXPECT_EQ(InferenceContext::kUnknownDim,
|
||||
c.Value(InferenceContext::kUnknownDim));
|
||||
EXPECT_EQ(1, c.Value(1));
|
||||
@ -122,7 +123,7 @@ TEST_F(ShapeInferenceTest, Run) {
|
||||
NodeDef def;
|
||||
def.set_name("foo");
|
||||
def.set_op("foo_op");
|
||||
InferenceContext c(&def, MakeOpDef(3, 2), {S({1})}, {}, {});
|
||||
InferenceContext c(&def, MakeOpDef(3, 2), {S({1})}, {}, {}, {}, {});
|
||||
|
||||
{
|
||||
auto fn = [](InferenceContext* c) {
|
||||
@ -154,7 +155,7 @@ TEST_F(ShapeInferenceTest, Run) {
|
||||
TEST_F(ShapeInferenceTest, RankAndDimInspection) {
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(3, 2), {Unknown(), S({1, -1, 3}), S({})},
|
||||
{}, {});
|
||||
{}, {}, {}, {});
|
||||
EXPECT_EQ(3, c.num_inputs());
|
||||
EXPECT_EQ(2, c.num_outputs());
|
||||
|
||||
@ -195,7 +196,8 @@ TEST_F(ShapeInferenceTest, RankAndDimInspection) {
|
||||
TEST_F(ShapeInferenceTest, NumElements) {
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(3, 2),
|
||||
{Unknown(), S({1, -1, 3}), S({5, 4, 3, 2})}, {}, {});
|
||||
{Unknown(), S({1, -1, 3}), S({5, 4, 3, 2})}, {}, {}, {},
|
||||
{});
|
||||
|
||||
EXPECT_EQ("?", c.DebugString(c.NumElements(c.input(0))));
|
||||
EXPECT_EQ("?", c.DebugString(c.NumElements(c.input(1))));
|
||||
@ -208,7 +210,8 @@ TEST_F(ShapeInferenceTest, NumElements) {
|
||||
|
||||
TEST_F(ShapeInferenceTest, WithRank) {
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(2, 2), {Unknown(), S({1, -1, 3})}, {}, {});
|
||||
InferenceContext c(&def, MakeOpDef(2, 2), {Unknown(), S({1, -1, 3})}, {}, {},
|
||||
{}, {});
|
||||
|
||||
auto in0 = c.input(0);
|
||||
auto in1 = c.input(1);
|
||||
@ -246,7 +249,8 @@ TEST_F(ShapeInferenceTest, WithRank) {
|
||||
|
||||
TEST_F(ShapeInferenceTest, WithRankAtMost) {
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(2, 2), {Unknown(), S({1, -1, 3})}, {}, {});
|
||||
InferenceContext c(&def, MakeOpDef(2, 2), {Unknown(), S({1, -1, 3})}, {}, {},
|
||||
{}, {});
|
||||
|
||||
auto in0 = c.input(0);
|
||||
auto in1 = c.input(1);
|
||||
@ -284,7 +288,8 @@ TEST_F(ShapeInferenceTest, WithRankAtMost) {
|
||||
|
||||
TEST_F(ShapeInferenceTest, WithRankAtLeast) {
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(2, 2), {Unknown(), S({1, -1, 3})}, {}, {});
|
||||
InferenceContext c(&def, MakeOpDef(2, 2), {Unknown(), S({1, -1, 3})}, {}, {},
|
||||
{}, {});
|
||||
|
||||
auto in0 = c.input(0);
|
||||
auto in1 = c.input(1);
|
||||
@ -322,7 +327,7 @@ TEST_F(ShapeInferenceTest, WithRankAtLeast) {
|
||||
|
||||
TEST_F(ShapeInferenceTest, WithValue) {
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(1, 2), {S({1, -1})}, {}, {});
|
||||
InferenceContext c(&def, MakeOpDef(1, 2), {S({1, -1})}, {}, {}, {}, {});
|
||||
|
||||
auto d0 = c.Dim(c.input(0), 0);
|
||||
auto d1 = c.Dim(c.input(0), 1);
|
||||
@ -363,7 +368,8 @@ TEST_F(ShapeInferenceTest, WithValue) {
|
||||
|
||||
TEST_F(ShapeInferenceTest, MergeDim) {
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(1, 2), {S({2, -1, 2, 1, -1})}, {}, {});
|
||||
InferenceContext c(&def, MakeOpDef(1, 2), {S({2, -1, 2, 1, -1})}, {}, {}, {},
|
||||
{});
|
||||
|
||||
auto d2 = c.Dim(c.input(0), 0);
|
||||
auto d_unknown = c.Dim(c.input(0), 1);
|
||||
@ -412,7 +418,7 @@ TEST_F(ShapeInferenceTest, MergeShape) {
|
||||
InferenceContext c(&def, MakeOpDef(7, 2),
|
||||
{Unknown(), S({1, 2}), S({-1, 2}), S({1, -1}), S({1, 3}),
|
||||
Unknown(), S({1})},
|
||||
{}, {});
|
||||
{}, {}, {}, {});
|
||||
|
||||
auto s_unknown = c.input(0);
|
||||
auto s_1_2 = c.input(1);
|
||||
@ -483,7 +489,7 @@ TEST_F(ShapeInferenceTest, MergePrefix) {
|
||||
{
|
||||
Unknown(), S({-1, 2}), S({1, -1, 3}), S({2, 4}),
|
||||
},
|
||||
{}, {});
|
||||
{}, {}, {}, {});
|
||||
|
||||
auto s_unknown = c.input(0);
|
||||
auto s_u_2 = c.input(1);
|
||||
@ -536,7 +542,7 @@ TEST_F(ShapeInferenceTest, MergePrefix) {
|
||||
TEST_F(ShapeInferenceTest, Subshape) {
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(2, 2), {S({1, 2, 3, -1, 5}), Unknown()},
|
||||
{}, {});
|
||||
{}, {}, {}, {});
|
||||
|
||||
ShapeHandle unknown = c.input(1);
|
||||
ShapeHandle out;
|
||||
@ -611,7 +617,7 @@ TEST_F(ShapeInferenceTest, Subshape) {
|
||||
TEST_F(ShapeInferenceTest, Concatenate) {
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(3, 2),
|
||||
{S({1, -1, 3}), S({4, 5}), Unknown()}, {}, {});
|
||||
{S({1, -1, 3}), S({4, 5}), Unknown()}, {}, {}, {}, {});
|
||||
|
||||
auto in0 = c.input(0);
|
||||
auto in1 = c.input(1);
|
||||
@ -637,7 +643,8 @@ TEST_F(ShapeInferenceTest, Concatenate) {
|
||||
|
||||
TEST_F(ShapeInferenceTest, ReplaceDim) {
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(2, 0), {S({1, 2, 3}), Unknown()}, {}, {});
|
||||
InferenceContext c(&def, MakeOpDef(2, 0), {S({1, 2, 3}), Unknown()}, {}, {},
|
||||
{}, {});
|
||||
|
||||
auto in = c.input(0);
|
||||
auto unknown = c.input(1);
|
||||
@ -668,7 +675,8 @@ TEST_F(ShapeInferenceTest, ReplaceDim) {
|
||||
|
||||
TEST_F(ShapeInferenceTest, MakeShape) {
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(1, 2), {S({1, 2, 3, -1, 5})}, {}, {});
|
||||
InferenceContext c(&def, MakeOpDef(1, 2), {S({1, 2, 3, -1, 5})}, {}, {}, {},
|
||||
{});
|
||||
|
||||
std::vector<DimensionHandle> dims;
|
||||
auto in0 = c.input(0);
|
||||
@ -693,7 +701,7 @@ TEST_F(ShapeInferenceTest, MakeShape) {
|
||||
TEST_F(ShapeInferenceTest, UnknownShape) {
|
||||
NodeDef def;
|
||||
std::vector<ShapeHandle> empty;
|
||||
InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {});
|
||||
InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {}, {}, {});
|
||||
|
||||
auto u0 = c.UnknownShape();
|
||||
auto u1 = c.UnknownShape();
|
||||
@ -705,7 +713,7 @@ TEST_F(ShapeInferenceTest, UnknownShape) {
|
||||
TEST_F(ShapeInferenceTest, Scalar) {
|
||||
NodeDef def;
|
||||
std::vector<ShapeHandle> empty;
|
||||
InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {});
|
||||
InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {}, {}, {});
|
||||
|
||||
auto s0 = c.Scalar();
|
||||
EXPECT_EQ("[]", c.DebugString(s0));
|
||||
@ -716,7 +724,7 @@ TEST_F(ShapeInferenceTest, Scalar) {
|
||||
TEST_F(ShapeInferenceTest, Vector) {
|
||||
NodeDef def;
|
||||
std::vector<ShapeHandle> empty;
|
||||
InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {});
|
||||
InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {}, {}, {});
|
||||
|
||||
auto s0 = c.Vector(1);
|
||||
EXPECT_EQ("[1]", c.DebugString(s0));
|
||||
@ -732,7 +740,7 @@ TEST_F(ShapeInferenceTest, Vector) {
|
||||
TEST_F(ShapeInferenceTest, Matrix) {
|
||||
NodeDef def;
|
||||
std::vector<ShapeHandle> empty;
|
||||
InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {});
|
||||
InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {}, {}, {});
|
||||
|
||||
auto s0 = c.Matrix(1, 2);
|
||||
EXPECT_EQ("[1,2]", c.DebugString(s0));
|
||||
@ -754,7 +762,7 @@ TEST_F(ShapeInferenceTest, Matrix) {
|
||||
TEST_F(ShapeInferenceTest, MakeShapeFromShapeTensor) {
|
||||
auto create = [&](Tensor* t) {
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(1, 0), {Unknown()}, {t}, {});
|
||||
InferenceContext c(&def, MakeOpDef(1, 0), {Unknown()}, {t}, {}, {}, {});
|
||||
ShapeHandle out;
|
||||
Status s = c.MakeShapeFromShapeTensor(0, &out);
|
||||
if (s.ok()) {
|
||||
@ -806,7 +814,8 @@ TEST_F(ShapeInferenceTest, MakeShapeFromShapeTensor) {
|
||||
// Test when the input shape is wrong.
|
||||
{
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(1, 0), {S({1, -1})}, {nullptr}, {});
|
||||
InferenceContext c(&def, MakeOpDef(1, 0), {S({1, -1})}, {nullptr}, {}, {},
|
||||
{});
|
||||
ShapeHandle out;
|
||||
EXPECT_EQ("Shape must be rank 1 but is rank 2",
|
||||
c.MakeShapeFromShapeTensor(0, &out).error_message());
|
||||
@ -816,7 +825,7 @@ TEST_F(ShapeInferenceTest, MakeShapeFromShapeTensor) {
|
||||
TEST_F(ShapeInferenceTest, MakeShapeFromShapeProto) {
|
||||
NodeDef def;
|
||||
std::vector<ShapeHandle> empty;
|
||||
InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {});
|
||||
InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {}, {}, {});
|
||||
TensorShapeProto proto;
|
||||
|
||||
// With a set unknown rank.
|
||||
@ -852,7 +861,7 @@ TEST_F(ShapeInferenceTest, MakeShapeFromShapeProto) {
|
||||
TEST_F(ShapeInferenceTest, MakeDim) {
|
||||
NodeDef def;
|
||||
std::vector<ShapeHandle> empty;
|
||||
InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {});
|
||||
InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {}, {}, {});
|
||||
|
||||
auto d0 = c.MakeDim(1);
|
||||
auto d1 = c.MakeDim(1);
|
||||
@ -866,7 +875,7 @@ TEST_F(ShapeInferenceTest, MakeDim) {
|
||||
TEST_F(ShapeInferenceTest, UnknownDim) {
|
||||
NodeDef def;
|
||||
std::vector<ShapeHandle> empty;
|
||||
InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {});
|
||||
InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {}, {}, {});
|
||||
|
||||
auto d0 = c.UnknownDim();
|
||||
auto d1 = c.UnknownDim();
|
||||
@ -878,7 +887,7 @@ TEST_F(ShapeInferenceTest, UnknownDim) {
|
||||
TEST_F(ShapeInferenceTest, UnknownShapeOfRank) {
|
||||
NodeDef def;
|
||||
std::vector<ShapeHandle> empty;
|
||||
InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {});
|
||||
InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {}, {}, {});
|
||||
|
||||
auto unknown_shape_of_rank_3 = c.UnknownShapeOfRank(3);
|
||||
EXPECT_EQ("[?,?,?]", c.DebugString(unknown_shape_of_rank_3));
|
||||
@ -892,7 +901,7 @@ TEST_F(ShapeInferenceTest, InputTensors) {
|
||||
const Tensor t2 = tensorflow::test::AsTensor<float>({20, 30});
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(3, 2), {S({1}), S({2}), S({3})},
|
||||
{&t1, &t2}, {});
|
||||
{&t1, &t2}, {}, {}, {});
|
||||
|
||||
EXPECT_TRUE(c.input_tensor(0) == &t1);
|
||||
EXPECT_TRUE(c.input_tensor(1) == &t2);
|
||||
@ -903,7 +912,8 @@ TEST_F(ShapeInferenceTest, MakeDimForScalarInput) {
|
||||
Tensor t1 = tensorflow::test::AsScalar<int32>(20);
|
||||
Tensor t2 = tensorflow::test::AsScalar<int32>(-1);
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(2, 2), {S({}), S({})}, {&t1, &t2}, {});
|
||||
InferenceContext c(&def, MakeOpDef(2, 2), {S({}), S({})}, {&t1, &t2}, {}, {},
|
||||
{});
|
||||
|
||||
DimensionHandle d;
|
||||
EXPECT_TRUE(c.MakeDimForScalarInput(0, &d).ok());
|
||||
@ -934,7 +944,7 @@ TEST_F(ShapeInferenceTest, GetAttr) {
|
||||
.ok());
|
||||
|
||||
std::vector<ShapeHandle> empty;
|
||||
InferenceContext c(&def, op_reg_data.op_def, empty, {}, {});
|
||||
InferenceContext c(&def, op_reg_data.op_def, empty, {}, {}, {}, {});
|
||||
string value;
|
||||
EXPECT_TRUE(c.GetAttr("foo", &value).ok());
|
||||
EXPECT_EQ("bar", value);
|
||||
@ -942,7 +952,8 @@ TEST_F(ShapeInferenceTest, GetAttr) {
|
||||
|
||||
TEST_F(ShapeInferenceTest, Divide) {
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(1, 2), {S({6, -1, 1, 2, 0})}, {}, {});
|
||||
InferenceContext c(&def, MakeOpDef(1, 2), {S({6, -1, 1, 2, 0})}, {}, {}, {},
|
||||
{});
|
||||
|
||||
auto s = c.input(0);
|
||||
auto d_6 = c.Dim(s, 0);
|
||||
@ -1004,7 +1015,7 @@ TEST_F(ShapeInferenceTest, Divide) {
|
||||
|
||||
TEST_F(ShapeInferenceTest, Add) {
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(1, 2), {S({6, -1, 0})}, {}, {});
|
||||
InferenceContext c(&def, MakeOpDef(1, 2), {S({6, -1, 0})}, {}, {}, {}, {});
|
||||
|
||||
auto s = c.input(0);
|
||||
auto d_6 = c.Dim(s, 0);
|
||||
@ -1055,7 +1066,7 @@ TEST_F(ShapeInferenceTest, Add) {
|
||||
|
||||
TEST_F(ShapeInferenceTest, Subtract) {
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(1, 2), {S({6, -1, 0, 5})}, {}, {});
|
||||
InferenceContext c(&def, MakeOpDef(1, 2), {S({6, -1, 0, 5})}, {}, {}, {}, {});
|
||||
|
||||
auto s = c.input(0);
|
||||
auto d_6 = c.Dim(s, 0);
|
||||
@ -1104,7 +1115,7 @@ TEST_F(ShapeInferenceTest, Subtract) {
|
||||
|
||||
TEST_F(ShapeInferenceTest, Multiply) {
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(1, 2), {S({6, -1, 0, 1})}, {}, {});
|
||||
InferenceContext c(&def, MakeOpDef(1, 2), {S({6, -1, 0, 1})}, {}, {}, {}, {});
|
||||
|
||||
auto s = c.input(0);
|
||||
auto d_6 = c.Dim(s, 0);
|
||||
@ -1157,7 +1168,7 @@ TEST_F(ShapeInferenceTest, Multiply) {
|
||||
TEST_F(ShapeInferenceTest, FullyDefined) {
|
||||
NodeDef def;
|
||||
std::vector<ShapeHandle> empty;
|
||||
InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {});
|
||||
InferenceContext c(&def, MakeOpDef(0, 2), empty, {}, {}, {}, {});
|
||||
|
||||
// No rank or missing dimension information should return false.
|
||||
EXPECT_FALSE(c.FullyDefined(c.UnknownShape()));
|
||||
@ -1170,7 +1181,7 @@ TEST_F(ShapeInferenceTest, FullyDefined) {
|
||||
|
||||
TEST_F(ShapeInferenceTest, Min) {
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(1, 2), {S({1, 2, -1, 0})}, {}, {});
|
||||
InferenceContext c(&def, MakeOpDef(1, 2), {S({1, 2, -1, 0})}, {}, {}, {}, {});
|
||||
|
||||
auto s = c.input(0);
|
||||
auto d_1 = c.Dim(s, 0);
|
||||
@ -1218,7 +1229,7 @@ TEST_F(ShapeInferenceTest, Min) {
|
||||
|
||||
TEST_F(ShapeInferenceTest, Max) {
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(1, 2), {S({1, 2, -1})}, {}, {});
|
||||
InferenceContext c(&def, MakeOpDef(1, 2), {S({1, 2, -1})}, {}, {}, {}, {});
|
||||
|
||||
auto s = c.input(0);
|
||||
auto d_1 = c.Dim(s, 0);
|
||||
@ -1256,7 +1267,7 @@ TEST_F(ShapeInferenceTest, Max) {
|
||||
TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownShapes) {
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(3, 1), {Unknown(), Unknown(), Unknown()},
|
||||
{}, {});
|
||||
{}, {}, {}, {});
|
||||
EXPECT_EQ(3, c.num_inputs());
|
||||
EXPECT_EQ(1, c.num_outputs());
|
||||
|
||||
@ -1269,7 +1280,7 @@ TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownShapes) {
|
||||
TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownDims) {
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(3, 1), {S({-1, -1}), S({-1}), S({-1})}, {},
|
||||
{});
|
||||
{}, {}, {});
|
||||
EXPECT_EQ(3, c.num_inputs());
|
||||
EXPECT_EQ(1, c.num_outputs());
|
||||
|
||||
@ -1281,8 +1292,8 @@ TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownDims) {
|
||||
|
||||
TEST_F(ShapeInferenceTest, ValidateSparseTensor_InvalidIndicesRank) {
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(3, 1), {S({-1}), S({-1}), S({-1})}, {},
|
||||
{});
|
||||
InferenceContext c(&def, MakeOpDef(3, 1), {S({-1}), S({-1}), S({-1})}, {}, {},
|
||||
{}, {});
|
||||
EXPECT_EQ(3, c.num_inputs());
|
||||
EXPECT_EQ(1, c.num_outputs());
|
||||
|
||||
@ -1295,8 +1306,8 @@ TEST_F(ShapeInferenceTest, ValidateSparseTensor_InvalidIndicesRank) {
|
||||
|
||||
TEST_F(ShapeInferenceTest, ValidateSparseTensor_InvalidNumElements) {
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({4}), S({3})}, {},
|
||||
{});
|
||||
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({4}), S({3})}, {}, {},
|
||||
{}, {});
|
||||
EXPECT_EQ(3, c.num_inputs());
|
||||
EXPECT_EQ(1, c.num_outputs());
|
||||
|
||||
@ -1309,8 +1320,8 @@ TEST_F(ShapeInferenceTest, ValidateSparseTensor_InvalidNumElements) {
|
||||
|
||||
TEST_F(ShapeInferenceTest, ValidateSparseTensor_InvalidRank) {
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({5}), S({4})}, {},
|
||||
{});
|
||||
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({5}), S({4})}, {}, {},
|
||||
{}, {});
|
||||
EXPECT_EQ(3, c.num_inputs());
|
||||
EXPECT_EQ(1, c.num_outputs());
|
||||
|
||||
@ -1324,7 +1335,7 @@ TEST_F(ShapeInferenceTest, ValidateSparseTensor_InvalidRank) {
|
||||
TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownNumIndexElements) {
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(3, 1), {S({-1, 3}), S({5}), S({3})}, {},
|
||||
{});
|
||||
{}, {}, {});
|
||||
EXPECT_EQ(3, c.num_inputs());
|
||||
EXPECT_EQ(1, c.num_outputs());
|
||||
|
||||
@ -1337,7 +1348,7 @@ TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownNumIndexElements) {
|
||||
TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownNumValueElements) {
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({-1}), S({3})}, {},
|
||||
{});
|
||||
{}, {}, {});
|
||||
EXPECT_EQ(3, c.num_inputs());
|
||||
EXPECT_EQ(1, c.num_outputs());
|
||||
|
||||
@ -1350,7 +1361,7 @@ TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownNumValueElements) {
|
||||
TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownIndexRank) {
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, -1}), S({5}), S({3})}, {},
|
||||
{});
|
||||
{}, {}, {});
|
||||
EXPECT_EQ(3, c.num_inputs());
|
||||
EXPECT_EQ(1, c.num_outputs());
|
||||
|
||||
@ -1363,7 +1374,7 @@ TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownIndexRank) {
|
||||
TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownShapeRank) {
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({5}), S({-1})}, {},
|
||||
{});
|
||||
{}, {}, {});
|
||||
EXPECT_EQ(3, c.num_inputs());
|
||||
EXPECT_EQ(1, c.num_outputs());
|
||||
|
||||
@ -1375,8 +1386,8 @@ TEST_F(ShapeInferenceTest, ValidateSparseTensor_UnknownShapeRank) {
|
||||
|
||||
TEST_F(ShapeInferenceTest, ValidateSparseTensor) {
|
||||
NodeDef def;
|
||||
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({5}), S({3})}, {},
|
||||
{});
|
||||
InferenceContext c(&def, MakeOpDef(3, 1), {S({5, 3}), S({5}), S({3})}, {}, {},
|
||||
{}, {});
|
||||
EXPECT_EQ(3, c.num_inputs());
|
||||
EXPECT_EQ(1, c.num_outputs());
|
||||
|
||||
|
@ -44,8 +44,7 @@ Status ShapeInferenceTestutil::InferShapes(ShapeInferenceTestOp op,
|
||||
}
|
||||
|
||||
shape_inference::InferenceContext c(&op.node_def, op_reg_data->op_def,
|
||||
in_shapes, op.input_tensors,
|
||||
{} /* input_tensors_as_shapes */);
|
||||
in_shapes, op.input_tensors, {}, {}, {});
|
||||
TF_RETURN_IF_ERROR(c.construction_status());
|
||||
if (op_reg_data->shape_inference_fn == nullptr) {
|
||||
return errors::InvalidArgument(
|
||||
|
@ -435,6 +435,7 @@ class Tensor {
|
||||
friend class VariableOp; // For access to set_shape
|
||||
friend class AutoReloadVariableOp; // For access to set_shape
|
||||
friend class TensorTestHelper; // For access to set_shape
|
||||
friend class CreateVariableOp;
|
||||
|
||||
// Creates a tensor with the input datatype, shape and buf.
|
||||
//
|
||||
|
@ -1030,6 +1030,18 @@ tf_kernel_library(
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "resource_variable_ops",
|
||||
srcs = ["resource_variable_ops.cc"],
|
||||
deps = [
|
||||
":variable_ops",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:resource_variable_ops_op_lib",
|
||||
"//third_party/eigen3",
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "fact_op",
|
||||
prefix = "fact_op",
|
||||
|
49
tensorflow/core/kernels/resource_variable_ops.cc
Normal file
49
tensorflow/core/kernels/resource_variable_ops.cc
Normal file
@ -0,0 +1,49 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/resource_mgr.h"
|
||||
#include "tensorflow/core/kernels/variable_ops.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
REGISTER_RESOURCE_HANDLE_KERNEL(Var);
|
||||
|
||||
class CreateVariableOp : public OpKernel {
|
||||
public:
|
||||
CreateVariableOp(OpKernelConstruction* c) : OpKernel(c) {
|
||||
OP_REQUIRES_OK(c, c->GetAttr("dtype", &dtype_));
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* c) override {
|
||||
Var* var = new Var(dtype_);
|
||||
var->Ref();
|
||||
core::ScopedUnref ur(var);
|
||||
OP_REQUIRES_OK(c, CreateResource<Var>(c, HandleFromInput(c, 0), var));
|
||||
// TODO(apassos): this currently does not initialize the tensor, so it's
|
||||
// pointless, other than checking construction in tests. Fix this.
|
||||
}
|
||||
|
||||
private:
|
||||
DataType dtype_;
|
||||
};
|
||||
REGISTER_KERNEL_BUILDER(Name("CreateVariableOp").Device(DEVICE_CPU),
|
||||
CreateVariableOp);
|
||||
|
||||
} // namespace tensorflow
|
@ -26,6 +26,26 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Resource stored by variables in the resource manager.
|
||||
class Var : public ResourceBase {
|
||||
public:
|
||||
explicit Var(DataType dtype) : tensor_(dtype) {}
|
||||
mutex* mu() { return &mu_; }
|
||||
Tensor* tensor() { return &tensor_; }
|
||||
|
||||
string DebugString() override {
|
||||
return strings::StrCat(DataTypeString(tensor_.dtype()), "/",
|
||||
tensor_.shape().DebugString());
|
||||
}
|
||||
|
||||
private:
|
||||
mutex mu_;
|
||||
Tensor tensor_;
|
||||
|
||||
~Var() override {}
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(Var);
|
||||
};
|
||||
|
||||
class VariableOp : public OpKernel {
|
||||
public:
|
||||
explicit VariableOp(OpKernelConstruction* context) : OpKernel(context) {
|
||||
@ -59,25 +79,6 @@ class VariableOp : public OpKernel {
|
||||
}
|
||||
|
||||
private:
|
||||
class Var : public ResourceBase {
|
||||
public:
|
||||
explicit Var(DataType dtype) : tensor_(dtype) {}
|
||||
mutex* mu() { return &mu_; }
|
||||
Tensor* tensor() { return &tensor_; }
|
||||
|
||||
string DebugString() override {
|
||||
return strings::StrCat(DataTypeString(tensor_.dtype()), "/",
|
||||
tensor_.shape().DebugString());
|
||||
}
|
||||
|
||||
private:
|
||||
mutex mu_;
|
||||
Tensor tensor_;
|
||||
|
||||
~Var() override {}
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(Var);
|
||||
};
|
||||
|
||||
DataType dtype_;
|
||||
TensorShape shape_;
|
||||
|
||||
|
@ -1252,7 +1252,12 @@ REGISTER_OP("Identity")
|
||||
.Input("input: T")
|
||||
.Output("output: T")
|
||||
.Attr("T: type")
|
||||
.SetShapeFn(shape_inference::UnchangedShape)
|
||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||
c->set_output(0, c->input(0));
|
||||
c->set_output_handle_dtype(0, c->input_handle_dtype(0));
|
||||
c->set_output_handle_shape(0, c->input_handle_shape(0));
|
||||
return Status::OK();
|
||||
})
|
||||
.Doc(R"Doc(
|
||||
Return a tensor with the same shape and contents as the input tensor or value.
|
||||
)Doc");
|
||||
|
@ -15,7 +15,9 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/node_def_builder.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
#include "tensorflow/core/framework/shape_inference_testutil.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
@ -153,6 +155,21 @@ TEST(ArrayOpsTest, UnchangedShapes_ShapeFn) {
|
||||
INFER_OK(op, "[1,2,?,4,5];?;?", "in0");
|
||||
}
|
||||
|
||||
TEST(ArrayOpsTest, Identity_ShapeFnHandles) {
|
||||
const char* op_name = "Identity";
|
||||
ShapeInferenceTestOp op(op_name);
|
||||
// Check that handle dtypes are preserved.
|
||||
const OpRegistrationData* op_reg_data;
|
||||
TF_ASSERT_OK(OpRegistry::Global()->LookUp(op.name, &op_reg_data));
|
||||
shape_inference::InferenceContext c(&op.node_def, op_reg_data->op_def,
|
||||
{TensorShapeProto()}, {}, {}, {},
|
||||
{DT_BOOL});
|
||||
TF_ASSERT_OK(c.construction_status());
|
||||
ASSERT_TRUE(op_reg_data->shape_inference_fn != nullptr);
|
||||
TF_ASSERT_OK(c.Run(op_reg_data->shape_inference_fn));
|
||||
EXPECT_TRUE(c.output_handle_dtype(0) == DT_BOOL);
|
||||
}
|
||||
|
||||
TEST(ArrayOpsTest, Diag_ShapeFn) {
|
||||
ShapeInferenceTestOp op("Diag");
|
||||
INFER_OK(op, "?", "?");
|
||||
|
@ -911,6 +911,19 @@ REGISTER_OP("Select")
|
||||
.Output("output: T")
|
||||
.Attr("T: type")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
// Merge handle shape and dtype if applicable.
|
||||
if (c->input_handle_dtype(1) != c->input_handle_dtype(2)) {
|
||||
// TODO(apassos) resolve this in the manner of b/32476923
|
||||
return errors::InvalidArgument(
|
||||
"Trying to merge handles pointing to different dtypes.");
|
||||
}
|
||||
c->set_output_handle_dtype(0, c->input_handle_dtype(1));
|
||||
ShapeHandle output_handle_shape;
|
||||
TF_RETURN_IF_ERROR(c->Merge(c->input_handle_shape(1),
|
||||
c->input_handle_shape(2),
|
||||
&output_handle_shape));
|
||||
c->set_output_handle_shape(0, output_handle_shape);
|
||||
|
||||
// The inputs 'then' and 'else' must have the same shape.
|
||||
ShapeHandle data = c->input(1);
|
||||
ShapeHandle other = c->input(2);
|
||||
@ -961,8 +974,9 @@ REGISTER_OP("Select")
|
||||
}
|
||||
|
||||
c->set_output(0, data);
|
||||
|
||||
return Status::OK();
|
||||
})
|
||||
})
|
||||
.Doc(R"doc(
|
||||
Selects elements from `t` or `e`, depending on `condition`.
|
||||
|
||||
|
@ -216,6 +216,37 @@ TEST(MathOpsTest, Select_ShapeFn) {
|
||||
INFER_OK(op, "[2,?,?];[?,?,3];[?,2,?]", "[d0_0,d2_1,d1_2]");
|
||||
INFER_ERROR("Dimension 2 in both shapes must be equal, but are 3 and 5", op,
|
||||
"[2,?,5];[?,?,3];[?,2,?]");
|
||||
|
||||
// Test that handle shapes were merged.
|
||||
const OpRegistrationData* op_reg_data;
|
||||
TF_ASSERT_OK(OpRegistry::Global()->LookUp(op.name, &op_reg_data));
|
||||
TensorShapeProto i0;
|
||||
i0.add_dim()->set_size(1);
|
||||
i0.add_dim()->set_size(-1);
|
||||
TensorShapeProto i1;
|
||||
i1.add_dim()->set_size(-1);
|
||||
i1.add_dim()->set_size(2);
|
||||
|
||||
ASSERT_TRUE(op_reg_data->shape_inference_fn != nullptr);
|
||||
shape_inference::InferenceContext c(
|
||||
&op.node_def, op_reg_data->op_def,
|
||||
{TensorShapeProto(), TensorShapeProto(), TensorShapeProto()}, {}, {},
|
||||
{TensorShapeProto(), i0, i1}, {});
|
||||
TF_ASSERT_OK(c.construction_status());
|
||||
TF_ASSERT_OK(c.Run(op_reg_data->shape_inference_fn));
|
||||
EXPECT_TRUE(c.FullyDefined(c.output_handle_shape(0)));
|
||||
EXPECT_EQ("[1,2]", c.DebugString(c.output_handle_shape(0)));
|
||||
|
||||
// Expect an error when the shapes can't be merged.
|
||||
TensorShapeProto i2;
|
||||
i1.add_dim()->set_size(2);
|
||||
i1.add_dim()->set_size(2);
|
||||
shape_inference::InferenceContext c2(
|
||||
&op.node_def, op_reg_data->op_def,
|
||||
{TensorShapeProto(), TensorShapeProto(), TensorShapeProto()}, {}, {},
|
||||
{TensorShapeProto(), i0, i2}, {});
|
||||
TF_ASSERT_OK(c.construction_status());
|
||||
EXPECT_FALSE(c2.Run(op_reg_data->shape_inference_fn).ok());
|
||||
}
|
||||
|
||||
TEST(MathOpsTest, Range_ShapeFn) {
|
||||
|
80
tensorflow/core/ops/resource_variable_ops.cc
Normal file
80
tensorflow/core/ops/resource_variable_ops.cc
Normal file
@ -0,0 +1,80 @@
|
||||
// Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// ============================================================================
|
||||
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/resource_mgr.h"
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
REGISTER_OP("VarHandleOp")
|
||||
.Attr("container: string = ''")
|
||||
.Attr("shared_name: string = ''")
|
||||
.Attr("dtype: type")
|
||||
.Attr("shape: shape")
|
||||
.Output("resource: resource")
|
||||
.SetIsStateful()
|
||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||
c->set_output(0, c->Scalar());
|
||||
DataType t;
|
||||
c->GetAttr("dtype", &t);
|
||||
c->set_output_handle_dtype(0, t);
|
||||
TensorShapeProto p;
|
||||
c->GetAttr("shape", &p);
|
||||
shape_inference::ShapeHandle s;
|
||||
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeProto(p, &s));
|
||||
c->set_output_handle_shape(0, s);
|
||||
return Status::OK();
|
||||
})
|
||||
.Doc(R"(
|
||||
Creates a handle to a Variable resource.
|
||||
|
||||
container: the container this variable is placed in.
|
||||
shared_name: the name by which this variable is referred to.
|
||||
dtype: the type of this variable. Must agree with the dtypes
|
||||
of all ops using this variable.
|
||||
shape: The (possibly partially specified) shape of this variable.
|
||||
)");
|
||||
|
||||
REGISTER_OP("CreateVariableOp")
|
||||
.Input("resource: resource")
|
||||
.Input("value: dtype")
|
||||
.Attr("dtype: type")
|
||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||
DataType handle_dtype = c->input_handle_dtype(0);
|
||||
DataType value_dtype;
|
||||
c->GetAttr("dtype", &value_dtype);
|
||||
if (handle_dtype != value_dtype) {
|
||||
return errors::InvalidArgument(
|
||||
"Trying to initialize handle for variable with wrong dtype. "
|
||||
"Expected ",
|
||||
handle_dtype, " got ", value_dtype);
|
||||
}
|
||||
shape_inference::ShapeHandle s = c->input_handle_shape(0);
|
||||
shape_inference::ShapeHandle value_shape = c->input(1);
|
||||
shape_inference::ShapeHandle unused;
|
||||
TF_RETURN_IF_ERROR(c->Merge(s, value_shape, &unused));
|
||||
return Status::OK();
|
||||
})
|
||||
.Doc(R"(
|
||||
Creates a variable resource.
|
||||
|
||||
resource: handle to the resource in which to store the variable.
|
||||
value: the value to set the new tensor to use.
|
||||
dtype: the dtype of the value.
|
||||
)");
|
||||
|
||||
} // namespace tensorflow
|
@ -217,23 +217,6 @@ cc_library(
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "cpp_shape_inference",
|
||||
srcs = ["framework/cpp_shape_inference.cc"],
|
||||
hdrs = ["framework/cpp_shape_inference.h"],
|
||||
copts = ["-Wno-sign-compare"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":numpy_lib",
|
||||
":py_func_lib",
|
||||
"//tensorflow/c:tf_status_helper",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:protos_cc",
|
||||
"//third_party/py/numpy:headers",
|
||||
"//util/python:python_headers",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "python_op_gen_main",
|
||||
srcs = [
|
||||
@ -284,6 +267,7 @@ py_library(
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":cpp_shape_inference_proto_py",
|
||||
":framework_for_generated_wrappers",
|
||||
":pywrap_tensorflow",
|
||||
],
|
||||
@ -660,6 +644,11 @@ tf_gen_op_wrapper_private_py(
|
||||
require_shape_functions = True,
|
||||
)
|
||||
|
||||
tf_gen_op_wrapper_private_py(
|
||||
name = "resource_variable_ops_gen",
|
||||
require_shape_functions = True,
|
||||
)
|
||||
|
||||
tf_gen_op_wrapper_private_py(
|
||||
name = "script_ops_gen",
|
||||
require_shape_functions = True,
|
||||
@ -989,6 +978,16 @@ py_library(
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "resource_variable_ops",
|
||||
srcs = ["ops/resource_variable_ops.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":framework",
|
||||
":resource_variable_ops_gen",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "nn",
|
||||
srcs = ["ops/nn.py"],
|
||||
@ -1409,6 +1408,7 @@ py_library(
|
||||
":partitioned_variables",
|
||||
":random_ops",
|
||||
":random_ops_gen",
|
||||
":resource_variable_ops",
|
||||
":resources",
|
||||
":rnn",
|
||||
":rnn_cell",
|
||||
@ -1690,6 +1690,7 @@ tf_proto_library(
|
||||
["**/*.proto"],
|
||||
exclude = [
|
||||
"util/protobuf/compare_test.proto",
|
||||
"framework/cpp_shape_inference.proto",
|
||||
],
|
||||
),
|
||||
go_api_version = 2,
|
||||
@ -1701,6 +1702,13 @@ tf_proto_library_py(
|
||||
srcs = ["util/protobuf/compare_test.proto"],
|
||||
)
|
||||
|
||||
tf_proto_library(
|
||||
name = "cpp_shape_inference_proto",
|
||||
srcs = ["framework/cpp_shape_inference.proto"],
|
||||
cc_api_version = 2,
|
||||
cc_libs = ["//tensorflow/core:protos_all_cc"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "protobuf_compare_test",
|
||||
size = "small",
|
||||
@ -1767,6 +1775,24 @@ py_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "cpp_shape_inference",
|
||||
srcs = ["framework/cpp_shape_inference.cc"],
|
||||
hdrs = ["framework/cpp_shape_inference.h"],
|
||||
copts = ["-Wno-sign-compare"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":cpp_shape_inference_proto_cc",
|
||||
":numpy_lib",
|
||||
":py_func_lib",
|
||||
"//tensorflow/c:tf_status_helper",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:protos_cc",
|
||||
"//third_party/py/numpy:headers",
|
||||
"//util/python:python_headers",
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_tests(
|
||||
name = "device_lib_test",
|
||||
size = "small",
|
||||
|
@ -20,8 +20,8 @@ from __future__ import print_function
|
||||
import numpy as np
|
||||
import six.moves
|
||||
|
||||
from tensorflow.core.framework import tensor_shape_pb2
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python.framework import cpp_shape_inference_pb2
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_util
|
||||
@ -567,8 +567,12 @@ def call_cpp_shape_fn(op, input_tensors_needed=None,
|
||||
the C++ shape function.
|
||||
|
||||
Returns:
|
||||
A TensorShape list of the output shapes of the op, as computed using the
|
||||
C++ shape inference function registered for the op.
|
||||
A dictionary with the following keys:
|
||||
shapes: A TensorShape list of the output shapes of the op, as computed
|
||||
using the C++ shape inference function registered for the op.
|
||||
handle_shapes: A TensorShape list of the shapes for handle outputs, if
|
||||
any.
|
||||
handle_dtypes: A list of DataType enums for the handle outputs, if any.
|
||||
|
||||
Raises:
|
||||
ValueError: If the C++ shape function returned an error (e.g. because the
|
||||
@ -576,8 +580,16 @@ def call_cpp_shape_fn(op, input_tensors_needed=None,
|
||||
according to the shape function).
|
||||
"""
|
||||
node_def_str = op.node_def.SerializeToString()
|
||||
input_shapes = [i.get_shape().as_proto().SerializeToString() for i in
|
||||
op.inputs]
|
||||
|
||||
def tensor_to_inference_result(t):
|
||||
r = cpp_shape_inference_pb2.CppShapeInferenceResult()
|
||||
r.shape.CopyFrom(t.get_shape().as_proto())
|
||||
# pylint: disable=protected-access
|
||||
r.handle_shape.CopyFrom(t._handle_shape)
|
||||
r.handle_dtype = t._handle_dtype
|
||||
# pylint: enable=protected-access
|
||||
return r.SerializeToString()
|
||||
input_shapes = [tensor_to_inference_result(i) for i in op.inputs]
|
||||
|
||||
input_tensors = [None for i in input_shapes]
|
||||
if input_tensors_needed:
|
||||
@ -596,10 +608,13 @@ def call_cpp_shape_fn(op, input_tensors_needed=None,
|
||||
raise ValueError(err.message)
|
||||
|
||||
# Convert TensorShapeProto values in output_shapes.
|
||||
result = [
|
||||
tensor_shape.TensorShape(tensor_shape_pb2.TensorShapeProto.FromString(s))
|
||||
result_protos = [
|
||||
cpp_shape_inference_pb2.CppShapeInferenceResult().FromString(s)
|
||||
for s in output_shapes
|
||||
]
|
||||
result = [r.shape for r in result_protos]
|
||||
result_handle_shapes = [r.handle_shape for r in result_protos]
|
||||
result_handle_dtypes = [r.handle_dtype for r in result_protos]
|
||||
|
||||
if debug_python_shape_fn:
|
||||
try:
|
||||
@ -616,4 +631,6 @@ def call_cpp_shape_fn(op, input_tensors_needed=None,
|
||||
str(result), str(python_result), str(op.node_def),
|
||||
",".join([str(i.get_shape()) for i in op.inputs])))
|
||||
|
||||
return result
|
||||
return {"shapes": result,
|
||||
"handle_shapes": result_handle_shapes,
|
||||
"handle_dtypes": result_handle_dtypes}
|
||||
|
@ -20,12 +20,32 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/python/framework/cpp_shape_inference.pb.h"
|
||||
#include "tensorflow/python/lib/core/py_func.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace swig {
|
||||
namespace {
|
||||
|
||||
void ProtoFromShapeHandle(tensorflow::shape_inference::ShapeHandle s,
|
||||
tensorflow::shape_inference::InferenceContext* c,
|
||||
TensorShapeProto* out) {
|
||||
if (c->RankKnown(s)) {
|
||||
const int32 rank = c->Rank(s);
|
||||
for (int i = 0; i < rank; ++i) {
|
||||
shape_inference::DimensionHandle d = c->Dim(s, i);
|
||||
auto* out_dim = out->add_dim();
|
||||
if (c->ValueKnown(d)) {
|
||||
out_dim->set_size(c->Value(d));
|
||||
} else {
|
||||
out_dim->set_size(-1);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
out->set_unknown_rank(true);
|
||||
}
|
||||
}
|
||||
|
||||
Status RunCppShapeInferenceImpl(
|
||||
const string& serialized_node_def,
|
||||
const std::vector<string>& input_serialized_shapes,
|
||||
@ -49,12 +69,21 @@ Status RunCppShapeInferenceImpl(
|
||||
|
||||
// Convert input shapes.
|
||||
std::vector<TensorShapeProto> input_shapes;
|
||||
std::vector<TensorShapeProto> input_handle_shapes;
|
||||
std::vector<DataType> input_handle_dtypes;
|
||||
input_shapes.resize(input_serialized_shapes.size());
|
||||
input_handle_shapes.resize(input_serialized_shapes.size());
|
||||
input_handle_dtypes.resize(input_serialized_shapes.size());
|
||||
CppShapeInferenceResult tmp;
|
||||
for (int i = 0; i < input_serialized_shapes.size(); ++i) {
|
||||
if (!input_shapes[i].ParseFromString(input_serialized_shapes[i])) {
|
||||
tmp.Clear();
|
||||
if (!tmp.ParseFromString(input_serialized_shapes[i])) {
|
||||
return errors::InvalidArgument(
|
||||
"Error parsing shape proto during cpp shape inference");
|
||||
}
|
||||
input_shapes[i].Swap(tmp.mutable_shape());
|
||||
input_handle_dtypes[i] = tmp.handle_dtype();
|
||||
input_handle_shapes[i].Swap(tmp.mutable_handle_shape());
|
||||
}
|
||||
|
||||
// Convert input tensor values;
|
||||
@ -73,34 +102,23 @@ Status RunCppShapeInferenceImpl(
|
||||
}
|
||||
|
||||
// Run shape inference.
|
||||
// TODO(cwhipkey): pass a value for input_tensors_as_shapes.
|
||||
tensorflow::shape_inference::InferenceContext c(
|
||||
&node, op_reg_data->op_def, input_shapes, input_tensors,
|
||||
{} /* input_tensors_as_shapes */);
|
||||
{} /* input_tensors_as_shapes */, input_handle_shapes,
|
||||
input_handle_dtypes);
|
||||
TF_RETURN_IF_ERROR(c.construction_status());
|
||||
|
||||
TF_RETURN_IF_ERROR(c.Run(op_reg_data->shape_inference_fn));
|
||||
|
||||
// Convert output shapes.
|
||||
output_tensor_shape_protos->resize(c.num_outputs());
|
||||
TensorShapeProto out;
|
||||
CppShapeInferenceResult out;
|
||||
for (int i = 0; i < c.num_outputs(); ++i) {
|
||||
shape_inference::ShapeHandle s = c.output(i);
|
||||
out.Clear();
|
||||
if (c.RankKnown(s)) {
|
||||
const int32 rank = c.Rank(s);
|
||||
for (int i = 0; i < rank; ++i) {
|
||||
shape_inference::DimensionHandle d = c.Dim(s, i);
|
||||
auto* out_dim = out.add_dim();
|
||||
if (c.ValueKnown(d)) {
|
||||
out_dim->set_size(c.Value(d));
|
||||
} else {
|
||||
out_dim->set_size(-1);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
out.set_unknown_rank(true);
|
||||
}
|
||||
ProtoFromShapeHandle(c.output(i), &c, out.mutable_shape());
|
||||
ProtoFromShapeHandle(c.output_handle_shape(i), &c,
|
||||
out.mutable_handle_shape());
|
||||
out.set_handle_dtype(c.output_handle_dtype(i));
|
||||
CHECK(out.AppendToString(&(*output_tensor_shape_protos)[i]));
|
||||
}
|
||||
|
||||
|
@ -36,7 +36,7 @@ namespace swig {
|
||||
// inference was successful.
|
||||
//
|
||||
// On success, <*output_shapes> is populated with the inferred output shapes (as
|
||||
// serialized TensorShapeProtos).
|
||||
// serialized CppShapeInferenceResult protos).
|
||||
// <*output_shapes> must be empty when this function is called.
|
||||
//
|
||||
// This is temporary code to be used during the migration
|
||||
|
13
tensorflow/python/framework/cpp_shape_inference.proto
Normal file
13
tensorflow/python/framework/cpp_shape_inference.proto
Normal file
@ -0,0 +1,13 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package tensorflow;
|
||||
option cc_enable_arenas = true;
|
||||
|
||||
import "tensorflow/core/framework/types.proto";
|
||||
import "tensorflow/core/framework/tensor_shape.proto";
|
||||
|
||||
message CppShapeInferenceResult {
|
||||
TensorShapeProto shape = 1;
|
||||
TensorShapeProto handle_shape = 2;
|
||||
DataType handle_dtype = 3;
|
||||
}
|
@ -32,6 +32,8 @@ from tensorflow.core.framework import attr_value_pb2
|
||||
from tensorflow.core.framework import function_pb2
|
||||
from tensorflow.core.framework import graph_pb2
|
||||
from tensorflow.core.framework import node_def_pb2
|
||||
from tensorflow.core.framework import tensor_shape_pb2
|
||||
from tensorflow.core.framework import types_pb2
|
||||
from tensorflow.core.framework import versions_pb2
|
||||
from tensorflow.python.framework import device as pydev
|
||||
from tensorflow.python.framework import dtypes
|
||||
@ -297,6 +299,10 @@ class Tensor(object):
|
||||
# to easily navigate a computation graph.
|
||||
self._consumers = []
|
||||
|
||||
# Attributes used for C++ shape inference. Not inspected, only forwarded.
|
||||
self._handle_shape = tensor_shape_pb2.TensorShapeProto()
|
||||
self._handle_dtype = types_pb2.DT_INVALID
|
||||
|
||||
@property
|
||||
def op(self):
|
||||
"""The `Operation` that produces this tensor as an output."""
|
||||
@ -1791,10 +1797,22 @@ def set_shapes_for_outputs(op):
|
||||
if shapes is None:
|
||||
raise RuntimeError(
|
||||
"Shape function for op %s did not return any shapes" % op)
|
||||
elif isinstance(shapes, dict):
|
||||
# Returned by call_cpp_shape_fn
|
||||
shapes_dict = shapes
|
||||
shapes = shapes_dict["shapes"]
|
||||
handle_shapes = shapes_dict["handle_shapes"]
|
||||
handle_dtypes = shapes_dict["handle_dtypes"]
|
||||
for output, handle_shape, handle_dtype in zip(op.outputs, handle_shapes, handle_dtypes):
|
||||
# pylint: disable=protected-access
|
||||
output._handle_shape = handle_shape
|
||||
output._handle_dtype = handle_dtype
|
||||
# pylint: enable=protected-access
|
||||
|
||||
if len(op.outputs) != len(shapes):
|
||||
raise RuntimeError(
|
||||
"Shape function for op %s returned %d shapes but expected %d" %
|
||||
(op, len(shapes), len(op.outputs)))
|
||||
"Shape function for op %s returned %d shapes but expected %d %s %s" %
|
||||
(op, len(shapes), len(op.outputs), shape_func.__name__, str(shapes)))
|
||||
for output, s in zip(op.outputs, shapes):
|
||||
output.set_shape(s)
|
||||
|
||||
|
@ -255,6 +255,16 @@ tf_py_test(
|
||||
additional_deps = ["//tensorflow:tensorflow_py"],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "resource_variable_ops_test",
|
||||
size = "small",
|
||||
srcs = ["resource_variable_ops_test.py"],
|
||||
additional_deps = [
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python:resource_variable_ops",
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "save_restore_ops_test",
|
||||
size = "small",
|
||||
|
51
tensorflow/python/kernel_tests/resource_variable_ops_test.py
Normal file
51
tensorflow/python/kernel_tests/resource_variable_ops_test.py
Normal file
@ -0,0 +1,51 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for tensorflow.ops.resource_variable_ops."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def testHandleDtypeShapeMatch(self):
|
||||
with self.test_session():
|
||||
handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
|
||||
with self.assertRaises(ValueError):
|
||||
resource_variable_ops.create_variable_op(
|
||||
handle, constant_op.constant(0.0, dtype=dtypes.float32)).run()
|
||||
with self.assertRaises(ValueError):
|
||||
resource_variable_ops.create_variable_op(
|
||||
handle, constant_op.constant([0], dtype=dtypes.int32)).run()
|
||||
resource_variable_ops.create_variable_op(
|
||||
handle, constant_op.constant(0, dtype=dtypes.int32)).run()
|
||||
|
||||
def testDtypeSurvivesIdentity(self):
|
||||
with self.test_session():
|
||||
handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
|
||||
id_handle = array_ops.identity(handle)
|
||||
resource_variable_ops.create_variable_op(
|
||||
id_handle, constant_op.constant(0, dtype=dtypes.int32)).run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
30
tensorflow/python/ops/resource_variable_ops.py
Normal file
30
tensorflow/python/ops/resource_variable_ops.py
Normal file
@ -0,0 +1,30 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Ops to use variables as resources."""
|
||||
|
||||
# pylint: disable=g-bad-name
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import common_shapes
|
||||
from tensorflow.python.framework import ops
|
||||
# go/tf-wildcard-import
|
||||
# pylint: disable=wildcard-import
|
||||
from tensorflow.python.ops.gen_resource_variable_ops import *
|
||||
# pylint: enable=wildcard-import
|
||||
|
||||
ops.RegisterShape("VarHandleOp")(common_shapes.call_cpp_shape_fn)
|
||||
ops.RegisterShape("CreateVariableOp")(common_shapes.call_cpp_shape_fn)
|
Loading…
Reference in New Issue
Block a user