C++ API: run shape inference as nodes are constructed

Here's an example of the new generated code:

AddN::AddN(const ::tensorflow::Scope& scope, ::tensorflow::InputList inputs) {
  if (!scope.ok()) return;
  auto _inputs = ::tensorflow::ops::AsNodeOutList(scope, inputs);
  if (!scope.ok()) return;
  ::tensorflow::Node* ret;
  const auto unique_name = scope.GetUniqueNameForOp("AddN");
  auto builder = ::tensorflow::NodeBuilder(unique_name, "AddN")
                     .Input(_inputs)
  ;
  scope.UpdateBuilder(&builder);
  scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
  if (!scope.ok()) return;
  scope.UpdateStatus(scope.DoShapeInference(ret));
  this->sum = Output(ret, 0);
}

Enabling shape inference unfortunately broke many tests. I fixed some of them, but for others I introduced a Scope::DisabledShapeInferenceScope() static method that returns a scope that doesn't perform shape inference. Eventually we should fix the tests that use this and remove it.

PiperOrigin-RevId: 165378429
This commit is contained in:
Skye Wanderman-Milne 2017-08-15 16:46:16 -07:00 committed by TensorFlower Gardener
parent 9ba0abc2f0
commit 477d49c9ea
26 changed files with 143 additions and 50 deletions

View File

@ -812,12 +812,8 @@ string OpInfo::GetConstructorBody() const {
strings::StrAppend(&body, " ", scope_str, ".UpdateStatus(builder.Finalize(", strings::StrAppend(&body, " ", scope_str, ".UpdateStatus(builder.Finalize(",
scope_str, ".graph(), &ret));\n"); scope_str, ".graph(), &ret));\n");
strings::StrAppend(&body, " ", return_on_error, "\n"); strings::StrAppend(&body, " ", return_on_error, "\n");
strings::StrAppend(&body, " ", scope_str, ".UpdateStatus(", scope_str,
// TODO(b/28152992): Enable this code-path once we have converted ".DoShapeInference(ret));\n");
// all python shape functions to call their C++ versions.
// strings::StrAppend(&body, " ", scope_str, ".UpdateStatus(", scope_str,
// ".refiner()->AddNode(ret));\n");
GetOutput(&body); GetOutput(&body);
return body; return body;

View File

@ -37,13 +37,14 @@ Scope& Scope::operator=(const Scope& other) {
} }
Scope::Impl::Impl(Graph* graph, Status* status, NameMap* name_map, Scope::Impl::Impl(Graph* graph, Status* status, NameMap* name_map,
ShapeRefiner* refiner) ShapeRefiner* refiner, bool disable_shape_inference)
: graph_(graph), : graph_(graph),
status_(status), status_(status),
name_map_(name_map), name_map_(name_map),
refiner_(refiner), refiner_(refiner),
scope_used_(nullptr), scope_used_(nullptr),
colocation_constraints_() {} colocation_constraints_(),
disable_shape_inference_(disable_shape_inference) {}
Scope::Impl::Impl(const std::shared_ptr<Graph>& graph, Scope::Impl::Impl(const std::shared_ptr<Graph>& graph,
const std::shared_ptr<Status>& status, const std::shared_ptr<Status>& status,
@ -54,13 +55,23 @@ Scope::Impl::Impl(const std::shared_ptr<Graph>& graph,
name_map_(name_map), name_map_(name_map),
refiner_(refiner), refiner_(refiner),
scope_used_(nullptr), scope_used_(nullptr),
colocation_constraints_() {} colocation_constraints_(),
disable_shape_inference_(false) {}
Scope Scope::NewRootScope() { Scope Scope::NewRootScope() {
Graph* graph = new Graph(OpRegistry::Global()); Graph* graph = new Graph(OpRegistry::Global());
ShapeRefiner* refiner = ShapeRefiner* refiner =
new ShapeRefiner(graph->versions(), graph->op_registry()); new ShapeRefiner(graph->versions(), graph->op_registry());
return Scope(new Impl(graph, new Status, new Impl::NameMap, refiner)); return Scope(new Impl(graph, new Status, new Impl::NameMap, refiner,
/* disable_shape_inference */ false));
}
Scope Scope::DisabledShapeInferenceScope() {
Graph* graph = new Graph(OpRegistry::Global());
ShapeRefiner* refiner =
new ShapeRefiner(graph->versions(), graph->op_registry());
return Scope(new Impl(graph, new Status, new Impl::NameMap, refiner,
/* disable_shape_inference */ true));
} }
Scope::Impl::Impl(const Scope& other, Tags::ScopeName, const string& name, Scope::Impl::Impl(const Scope& other, Tags::ScopeName, const string& name,
@ -77,7 +88,8 @@ Scope::Impl::Impl(const Scope& other, Tags::ScopeName, const string& name,
exit_on_error_(other.impl()->exit_on_error_), exit_on_error_(other.impl()->exit_on_error_),
kernel_label_(other.impl()->kernel_label_), kernel_label_(other.impl()->kernel_label_),
device_(other.impl()->device_), device_(other.impl()->device_),
colocation_constraints_(other.impl()->colocation_constraints_) {} colocation_constraints_(other.impl()->colocation_constraints_),
disable_shape_inference_(other.impl()->disable_shape_inference_) {}
Scope::Impl::Impl(const Scope& other, Tags::OpName, const string& name, Scope::Impl::Impl(const Scope& other, Tags::OpName, const string& name,
const string& op_name) const string& op_name)
@ -92,7 +104,8 @@ Scope::Impl::Impl(const Scope& other, Tags::OpName, const string& name,
exit_on_error_(other.impl()->exit_on_error_), exit_on_error_(other.impl()->exit_on_error_),
kernel_label_(other.impl()->kernel_label_), kernel_label_(other.impl()->kernel_label_),
device_(other.impl()->device_), device_(other.impl()->device_),
colocation_constraints_(other.impl()->colocation_constraints_) {} colocation_constraints_(other.impl()->colocation_constraints_),
disable_shape_inference_(other.impl()->disable_shape_inference_) {}
Scope::Impl::Impl(const Scope& other, Tags::ControlDeps, Scope::Impl::Impl(const Scope& other, Tags::ControlDeps,
std::vector<Operation> control_deps, bool clear_control_deps) std::vector<Operation> control_deps, bool clear_control_deps)
@ -113,7 +126,8 @@ Scope::Impl::Impl(const Scope& other, Tags::ControlDeps,
exit_on_error_(other.impl()->exit_on_error_), exit_on_error_(other.impl()->exit_on_error_),
kernel_label_(other.impl()->kernel_label_), kernel_label_(other.impl()->kernel_label_),
device_(other.impl()->device_), device_(other.impl()->device_),
colocation_constraints_(other.impl()->colocation_constraints_) {} colocation_constraints_(other.impl()->colocation_constraints_),
disable_shape_inference_(other.impl()->disable_shape_inference_) {}
Scope::Impl::Impl(const Scope& other, Tags::Device, const string& device) Scope::Impl::Impl(const Scope& other, Tags::Device, const string& device)
: graph_(other.impl()->graph_), : graph_(other.impl()->graph_),
@ -127,7 +141,8 @@ Scope::Impl::Impl(const Scope& other, Tags::Device, const string& device)
exit_on_error_(other.impl()->exit_on_error_), exit_on_error_(other.impl()->exit_on_error_),
kernel_label_(other.impl()->kernel_label_), kernel_label_(other.impl()->kernel_label_),
device_(device), device_(device),
colocation_constraints_(other.impl()->colocation_constraints_) {} colocation_constraints_(other.impl()->colocation_constraints_),
disable_shape_inference_(other.impl()->disable_shape_inference_) {}
Scope::Impl::Impl(const Scope& other, Tags::SingleUseScope, Scope::Impl::Impl(const Scope& other, Tags::SingleUseScope,
const string& op_name) const string& op_name)
@ -142,7 +157,8 @@ Scope::Impl::Impl(const Scope& other, Tags::SingleUseScope,
exit_on_error_(other.impl()->exit_on_error_), exit_on_error_(other.impl()->exit_on_error_),
kernel_label_(other.impl()->kernel_label_), kernel_label_(other.impl()->kernel_label_),
device_(other.impl()->device_), device_(other.impl()->device_),
colocation_constraints_(other.impl()->colocation_constraints_) {} colocation_constraints_(other.impl()->colocation_constraints_),
disable_shape_inference_(other.impl()->disable_shape_inference_) {}
Scope::Impl::Impl(const Scope& other, Tags::ExitOnError) Scope::Impl::Impl(const Scope& other, Tags::ExitOnError)
: graph_(other.impl()->graph_), : graph_(other.impl()->graph_),
@ -156,7 +172,8 @@ Scope::Impl::Impl(const Scope& other, Tags::ExitOnError)
exit_on_error_(true), exit_on_error_(true),
kernel_label_(other.impl()->kernel_label_), kernel_label_(other.impl()->kernel_label_),
device_(other.impl()->device_), device_(other.impl()->device_),
colocation_constraints_(other.impl()->colocation_constraints_) {} colocation_constraints_(other.impl()->colocation_constraints_),
disable_shape_inference_(other.impl()->disable_shape_inference_) {}
Scope::Impl::Impl(const Scope& other, Tags::KernelLabel, Scope::Impl::Impl(const Scope& other, Tags::KernelLabel,
const string& kernel_label) const string& kernel_label)
@ -171,7 +188,8 @@ Scope::Impl::Impl(const Scope& other, Tags::KernelLabel,
exit_on_error_(other.impl()->exit_on_error_), exit_on_error_(other.impl()->exit_on_error_),
kernel_label_(kernel_label), kernel_label_(kernel_label),
device_(other.impl()->device_), device_(other.impl()->device_),
colocation_constraints_(other.impl()->colocation_constraints_) {} colocation_constraints_(other.impl()->colocation_constraints_),
disable_shape_inference_(other.impl()->disable_shape_inference_) {}
Scope::Impl::Impl(const Scope& other, Tags::Colocate, Scope::Impl::Impl(const Scope& other, Tags::Colocate,
const Operation& colocate_with_op, bool clear_colocations) const Operation& colocate_with_op, bool clear_colocations)
@ -189,7 +207,8 @@ Scope::Impl::Impl(const Scope& other, Tags::Colocate,
colocation_constraints_( colocation_constraints_(
clear_colocations clear_colocations
? std::unordered_set<string>() ? std::unordered_set<string>()
: other.impl()->GetColocationConstraints(colocate_with_op)) {} : other.impl()->GetColocationConstraints(colocate_with_op)),
disable_shape_inference_(other.impl()->disable_shape_inference_) {}
std::unordered_set<string> Scope::Impl::GetColocationConstraints( std::unordered_set<string> Scope::Impl::GetColocationConstraints(
const Operation& colocate_with_op) const { const Operation& colocate_with_op) const {
@ -404,6 +423,11 @@ CompositeOpScopes Scope::GetCompositeOpScopes(
} }
} }
Status Scope::DoShapeInference(Node* node) const {
if (impl_->disable_shape_inference_) return Status::OK();
return impl_->refiner_->AddNode(node);
}
class InternalScope { class InternalScope {
public: public:
// NewScope doesn't take ownership of the inputs. // NewScope doesn't take ownership of the inputs.

View File

@ -199,6 +199,18 @@ class Scope {
// edges from the source and to the sink node, resolves back edges // edges from the source and to the sink node, resolves back edges
// by name), and makes sure the resulting graph is valid. // by name), and makes sure the resulting graph is valid.
Status ToGraph(Graph* g) const; Status ToGraph(Graph* g) const;
// Calls AddNode() using this scope's ShapeRefiner. This exists in the public
// API to prevent custom op wrappers from needing access to shape_refiner.h or
// scope_internal.h.
// TODO(skyewm): remove this from public API
Status DoShapeInference(Node* node) const;
// Creates a new root scope that causes all DoShapeInference() calls to return
// Status::OK() (on the returned scope and any subscopes). Used for testing.
// TODO(skyewm): fix tests that still require this and eventually remove, or
// at least remove from public API
static Scope DisabledShapeInferenceScope();
// END_SKIP_DOXYGEN // END_SKIP_DOXYGEN
const std::vector<Operation>& control_deps() const; const std::vector<Operation>& control_deps() const;

View File

@ -58,7 +58,8 @@ class Scope::Impl {
enum class Colocate; enum class Colocate;
}; };
Impl(Graph* graph, Status* status, NameMap* name_map, ShapeRefiner* refiner); Impl(Graph* graph, Status* status, NameMap* name_map, ShapeRefiner* refiner,
bool disable_shape_inference);
Impl(const Scope& other, Tags::ScopeName, const string& name, Impl(const Scope& other, Tags::ScopeName, const string& name,
bool copy_names); bool copy_names);
Impl(const Scope& other, Tags::OpName, const string& name, Impl(const Scope& other, Tags::OpName, const string& name,
@ -103,6 +104,10 @@ class Scope::Impl {
const string kernel_label_ = ""; const string kernel_label_ = "";
const string device_ = ""; const string device_ = "";
const std::unordered_set<string> colocation_constraints_; const std::unordered_set<string> colocation_constraints_;
// If true, Scope::DoShapeInference() always returns Status:OK().
// TODO(skyewm): remove this when possible
const bool disable_shape_inference_;
}; };
} // namespace tensorflow } // namespace tensorflow

View File

@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op.h"
namespace tensorflow { namespace tensorflow {
@ -24,6 +25,7 @@ REGISTER_OP("ThrowAway1")
.Attr("scope: int") .Attr("scope: int")
.Attr("builder: int = 1") .Attr("builder: int = 1")
.Attr("while: int") .Attr("while: int")
.SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc( .Doc(R"doc(
Op to test keywords and reserved words in input and attr names. Op to test keywords and reserved words in input and attr names.
@ -36,12 +38,20 @@ REGISTER_OP("ThrowAway2")
.Attr("scope: int = 2") .Attr("scope: int = 2")
.Attr("throw_away2: int = 2") .Attr("throw_away2: int = 2")
.Attr("attrs: int = 4") .Attr("attrs: int = 4")
.Attr("node: int = 4"); .Attr("node: int = 4")
.SetShapeFn(shape_inference::UnknownShape);
REGISTER_OP("ThrowAway3").Output("node: int32"); REGISTER_OP("ThrowAway3")
.Output("node: int32")
.SetShapeFn(shape_inference::UnknownShape);
REGISTER_OP("ThrowAway4").Input("node: int32"); REGISTER_OP("ThrowAway4")
.Input("node: int32")
.SetShapeFn(shape_inference::UnknownShape);
REGISTER_OP("ThrowAway5").Output("foo: int32").Attr("node: int = 4"); REGISTER_OP("ThrowAway5")
.Output("foo: int32")
.Attr("node: int = 4")
.SetShapeFn(shape_inference::UnknownShape);
} // namespace tensorflow } // namespace tensorflow

View File

@ -34,7 +34,9 @@ Output Const(const Scope& scope, const Input::Initializer& val) {
.Attr("dtype", val.tensor.dtype()); .Attr("dtype", val.tensor.dtype());
scope.UpdateBuilder(&builder); scope.UpdateBuilder(&builder);
scope.UpdateStatus(builder.Finalize(graph, &ret)); scope.UpdateStatus(builder.Finalize(graph, &ret));
if (!scope.ok()) return Output();
scope.UpdateStatus(scope.DoShapeInference(ret));
if (!scope.ok()) return Output(); if (!scope.ok()) return Output();
return Output(ret); return Output(ret);

View File

@ -56,6 +56,8 @@ Output Const(const Scope& scope, const Input::Initializer& val) {
scope.UpdateBuilder(&cast_builder); scope.UpdateBuilder(&cast_builder);
Node* ret; Node* ret;
scope.UpdateStatus(cast_builder.Finalize(scope.graph(), &ret)); scope.UpdateStatus(cast_builder.Finalize(scope.graph(), &ret));
if (!scope.ok()) return Output();
scope.UpdateStatus(scope.DoShapeInference(ret));
return Output(ret, 0); return Output(ret, 0);
} }

View File

@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op.h"
namespace tensorflow { namespace tensorflow {
@ -26,6 +27,7 @@ REGISTER_OP("XlaWhile")
.Attr("cond: func") .Attr("cond: func")
.Attr("body: func") .Attr("body: func")
.SetIsStateful() .SetIsStateful()
.SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc( .Doc(R"doc(
output = input; While (Cond(output)) { output = Body(output) } output = input; While (Cond(output)) { output = Body(output) }

View File

@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op.h"
namespace tensorflow { namespace tensorflow {
@ -22,6 +23,7 @@ REGISTER_OP("_XLASend")
.Attr("T: type") .Attr("T: type")
.Attr("tensor_name: string") .Attr("tensor_name: string")
.SetIsStateful() .SetIsStateful()
.SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc( .Doc(R"doc(
Sends the named tensor to another XLA computation. Sends the named tensor to another XLA computation.
@ -35,6 +37,7 @@ REGISTER_OP("_XLARecv")
.Attr("tensor_name: string") .Attr("tensor_name: string")
.Attr("shape: shape") .Attr("shape: shape")
.SetIsStateful() .SetIsStateful()
.SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc( .Doc(R"doc(
Receives the named tensor from another XLA computation. Receives the named tensor from another XLA computation.

View File

@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph.h"
@ -76,6 +77,8 @@ class DummyReadResourceCC {
scope.UpdateBuilder(&builder); scope.UpdateBuilder(&builder);
scope.UpdateStatus(builder.Finalize(scope.graph(), &ret)); scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
if (!scope.ok()) return; if (!scope.ok()) return;
scope.UpdateStatus(scope.DoShapeInference(ret));
if (!scope.ok()) return;
this->output_ = Output(ret, 0); this->output_ = Output(ret, 0);
} }
Node* node() const { return output_.node(); } Node* node() const { return output_.node(); }
@ -86,6 +89,7 @@ class DummyReadResourceCC {
REGISTER_OP("DummyReadResource") REGISTER_OP("DummyReadResource")
.Input("input: int32") .Input("input: int32")
.Output("output: int32") .Output("output: int32")
.SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc( .Doc(R"doc(
A dummy Op. A dummy Op.

View File

@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/framework/shape_inference.h"
@ -67,6 +68,7 @@ REGISTER_OP("TPUReplicate")
.Input("broadcast_inputs: Tbroadcast_inputs") .Input("broadcast_inputs: Tbroadcast_inputs")
.Input("variables: NumVariables * resource") .Input("variables: NumVariables * resource")
.Output("outputs: output_types") .Output("outputs: output_types")
.SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc( .Doc(R"doc(
Runs replicated computations on a distributed TPU system. Runs replicated computations on a distributed TPU system.

View File

@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status.h"
@ -154,7 +155,10 @@ global_tpu_array: A two-dimensional array. For each host (the outer
dimension) the array lists the global ids of the TPUs on that host. dimension) the array lists the global ids of the TPUs on that host.
)doc"); )doc");
REGISTER_OP("_ShutdownDistributedTPU").SetIsStateful().Doc(R"doc( REGISTER_OP("_ShutdownDistributedTPU")
.SetIsStateful()
.SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc(
An op that shuts down a running distributed TPU system. The Op returns An op that shuts down a running distributed TPU system. The Op returns
an error if no system is running. This Op must be run on the same an error if no system is running. This Op must be run on the same
TPU_SYSTEM device as the corresponding _ConfigureDistributedTPU was run TPU_SYSTEM device as the corresponding _ConfigureDistributedTPU was run
@ -184,6 +188,7 @@ tpu_ids: A vector containing the global TPU id of each TPU on the host.
REGISTER_OP("_DisconnectHostFromDistributedTPUSystem") REGISTER_OP("_DisconnectHostFromDistributedTPUSystem")
.Output("number_of_tpu_chips: int32") .Output("number_of_tpu_chips: int32")
.SetIsStateful() .SetIsStateful()
.SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc( .Doc(R"doc(
An op that disconnects the TPUs on a host from a running distributed An op that disconnects the TPUs on a host from a running distributed
TPU system. TPU system.
@ -196,6 +201,7 @@ REGISTER_OP("ConfigureDistributedTPU")
.Output("global_tpu_array: int32") .Output("global_tpu_array: int32")
.Attr("embedding_config: string = ''") .Attr("embedding_config: string = ''")
.SetIsStateful() .SetIsStateful()
.SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc( .Doc(R"doc(
An op that sets up the centralized structures for a distributed TPU An op that sets up the centralized structures for a distributed TPU
system. system.
@ -205,7 +211,10 @@ dimension) the array lists the global ids of the TPUs on that host.
embedding_config: Internal use. embedding_config: Internal use.
)doc"); )doc");
REGISTER_OP("ShutdownDistributedTPU").SetIsStateful().Doc(R"doc( REGISTER_OP("ShutdownDistributedTPU")
.SetIsStateful()
.SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc(
An op that shuts down a running distributed TPU system. The Op returns An op that shuts down a running distributed TPU system. The Op returns
an error if no system is running. an error if no system is running.
)doc"); )doc");

View File

@ -282,6 +282,7 @@ TEST_F(ConstantFoldingTest, TestNoReplaceFunctionCall) {
Status status; Status status;
Node* times_two = s.graph()->AddNode(def, &status); Node* times_two = s.graph()->AddNode(def, &status);
TF_ASSERT_OK(status); TF_ASSERT_OK(status);
TF_ASSERT_OK(s.DoShapeInference(times_two));
s.graph()->AddEdge(c.node(), 0, times_two, 0); s.graph()->AddEdge(c.node(), 0, times_two, 0);
auto times_two_send = auto times_two_send =
@ -297,7 +298,10 @@ TEST_F(ConstantFoldingTest, TestNoReplaceFunctionCall) {
EXPECT_FALSE(was_mutated); EXPECT_FALSE(was_mutated);
} }
REGISTER_OP("ConstantFoldingTestOp").Input("a: int64").Output("b: int64"); REGISTER_OP("ConstantFoldingTestOp")
.Input("a: int64")
.Output("b: int64")
.SetShapeFn(shape_inference::UnknownShape);
TEST_F(ConstantFoldingTest, TestNoReplaceNonCPUOp) { TEST_F(ConstantFoldingTest, TestNoReplaceNonCPUOp) {
Graph g(OpRegistry::Global()); Graph g(OpRegistry::Global());
@ -312,6 +316,7 @@ TEST_F(ConstantFoldingTest, TestNoReplaceNonCPUOp) {
Status status; Status status;
Node* non_cpu = s.graph()->AddNode(def, &status); Node* non_cpu = s.graph()->AddNode(def, &status);
TF_ASSERT_OK(status); TF_ASSERT_OK(status);
TF_ASSERT_OK(s.DoShapeInference(non_cpu));
auto non_cpu_send = auto non_cpu_send =
ops::_Send(s.WithOpName("non_cpu_send"), Output(non_cpu), ops::_Send(s.WithOpName("non_cpu_send"), Output(non_cpu),

View File

@ -284,6 +284,7 @@ Output Call(Scope* scope, const string& op_name, const string& fn_name,
Status status; Status status;
Node* n = scope->graph()->AddNode(def, &status); Node* n = scope->graph()->AddNode(def, &status);
TF_CHECK_OK(status); TF_CHECK_OK(status);
TF_CHECK_OK(scope->DoShapeInference(n));
for (int i = 0; i < inputs.size(); ++i) { for (int i = 0; i < inputs.size(); ++i) {
scope->graph()->AddEdge(inputs[i].node(), inputs[i].index(), n, i); scope->graph()->AddEdge(inputs[i].node(), inputs[i].index(), n, i);
} }
@ -989,7 +990,7 @@ TEST(OptimizationTest, RemoveDeadNodes) {
GraphDef expected; GraphDef expected;
{ {
Scope s = Scope::NewRootScope(); Scope s = Scope::DisabledShapeInferenceScope();
auto x = ops::_Arg(s.WithOpName("x"), DT_INT32, 0); auto x = ops::_Arg(s.WithOpName("x"), DT_INT32, 0);
auto o = ops::Const(s.WithOpName("o"), 1); auto o = ops::Const(s.WithOpName("o"), 1);
auto keep_me = ops::RandomUniform(s.WithOpName("keep_me"), {o}, DT_FLOAT); auto keep_me = ops::RandomUniform(s.WithOpName("keep_me"), {o}, DT_FLOAT);
@ -1070,7 +1071,7 @@ TEST(OptimizationTest, RemoveIdentityNodes) {
{{"y"}, "Add", {"a", "o"}, {{"T", T}}}}); {{"y"}, "Add", {"a", "o"}, {{"T", T}}}});
{ {
Scope s = Scope::NewRootScope(); Scope s = Scope::DisabledShapeInferenceScope();
auto x = ops::_Arg(s.WithOpName("x"), DT_INT32, 0); auto x = ops::_Arg(s.WithOpName("x"), DT_INT32, 0);
auto o = ops::Const(s.WithOpName("o"), 1); auto o = ops::Const(s.WithOpName("o"), 1);
auto a = ops::Square(s.WithOpName("a"), x); auto a = ops::Square(s.WithOpName("a"), x);
@ -1087,7 +1088,7 @@ TEST(OptimizationTest, RemoveIdentityNodes) {
} }
{ {
Scope s = Scope::NewRootScope(); Scope s = Scope::DisabledShapeInferenceScope();
auto x = ops::_Arg(s.WithOpName("x"), DT_INT32, 0); auto x = ops::_Arg(s.WithOpName("x"), DT_INT32, 0);
auto o = ops::Const(s.WithOpName("o"), 1); auto o = ops::Const(s.WithOpName("o"), 1);
auto a = ops::Square(s.WithOpName("a"), x); auto a = ops::Square(s.WithOpName("a"), x);
@ -1137,7 +1138,7 @@ TEST(OptimizationTest, RemoveListArrayConverter) {
{{"o", "o:sum"}}); {{"o", "o:sum"}});
{ {
Scope scope = Scope::NewRootScope(); Scope scope = Scope::DisabledShapeInferenceScope();
auto i = ops::_Arg(scope.WithOpName("i"), DT_FLOAT, 0); auto i = ops::_Arg(scope.WithOpName("i"), DT_FLOAT, 0);
auto zero = ops::Const(scope.WithOpName("zero"), 0); auto zero = ops::Const(scope.WithOpName("zero"), 0);
auto s = ops::Split(scope.WithOpName("s"), zero, i, 4); auto s = ops::Split(scope.WithOpName("s"), zero, i, 4);
@ -1222,7 +1223,7 @@ TEST(OptimizationTest, RemoveListArrayConverter_WithContolDeps) {
{{"o", "o:sum"}}); {{"o", "o:sum"}});
{ {
Scope s = Scope::NewRootScope(); Scope s = Scope::DisabledShapeInferenceScope();
auto i = ops::_Arg(s.WithOpName("i"), DT_FLOAT, 0); auto i = ops::_Arg(s.WithOpName("i"), DT_FLOAT, 0);
auto dummy = ops::Const(s.WithOpName("dummy"), 0); auto dummy = ops::Const(s.WithOpName("dummy"), 0);
auto x = ops::_ListToArray(s.WithOpName("x").WithControlDependencies(dummy), auto x = ops::_ListToArray(s.WithOpName("x").WithControlDependencies(dummy),

View File

@ -64,7 +64,7 @@ TEST_F(GpuStreamUtilTest, EmptyGraph) {
} }
TEST_F(GpuStreamUtilTest, SimpleGraphOneStream) { TEST_F(GpuStreamUtilTest, SimpleGraphOneStream) {
auto root = Scope::NewRootScope().ExitOnError(); auto root = Scope::DisabledShapeInferenceScope().ExitOnError();
ops::MatMul(root, {}, {}); ops::MatMul(root, {}, {});
Graph g(OpRegistry::Global()); Graph g(OpRegistry::Global());
TF_ASSERT_OK(root.ToGraph(&g)); TF_ASSERT_OK(root.ToGraph(&g));
@ -83,7 +83,7 @@ TEST_F(GpuStreamUtilTest, SimpleGraphOneStream) {
} }
TEST_F(GpuStreamUtilTest, SimpleGraphManyStreams) { TEST_F(GpuStreamUtilTest, SimpleGraphManyStreams) {
auto root = Scope::NewRootScope().ExitOnError(); auto root = Scope::DisabledShapeInferenceScope().ExitOnError();
ops::MatMul(root, {}, {}); ops::MatMul(root, {}, {});
Graph g(OpRegistry::Global()); Graph g(OpRegistry::Global());
TF_ASSERT_OK(root.ToGraph(&g)); TF_ASSERT_OK(root.ToGraph(&g));
@ -104,7 +104,7 @@ TEST_F(GpuStreamUtilTest, SimpleGraphManyStreams) {
} }
TEST_F(GpuStreamUtilTest, StreamOverrides) { TEST_F(GpuStreamUtilTest, StreamOverrides) {
auto root = Scope::NewRootScope().ExitOnError(); auto root = Scope::DisabledShapeInferenceScope().ExitOnError();
ops::_Recv(root.WithOpName("input"), DT_FLOAT, "input", "/cpu:0", 0, ops::_Recv(root.WithOpName("input"), DT_FLOAT, "input", "/cpu:0", 0,
"/device:GPU:0"); "/device:GPU:0");
Output n = ops::MatMul(root, {}, {}); Output n = ops::MatMul(root, {}, {});

View File

@ -882,7 +882,7 @@ TEST_F(ShapeRefinerTest, ConstantValueAsShape_Shape) {
} }
TEST_F(ShapeRefinerTest, ConstantValueAsShape_PackInt32) { TEST_F(ShapeRefinerTest, ConstantValueAsShape_PackInt32) {
Scope root = Scope::NewRootScope(); Scope root = Scope::DisabledShapeInferenceScope();
Node* scalar_non_const; Node* scalar_non_const;
TF_ASSERT_OK(NodeBuilder("in", "NonConstScalarInt32") TF_ASSERT_OK(NodeBuilder("in", "NonConstScalarInt32")
.Finalize(root.graph(), &scalar_non_const)); .Finalize(root.graph(), &scalar_non_const));
@ -914,7 +914,7 @@ TEST_F(ShapeRefinerTest, ConstantValueAsShape_PackInt32) {
} }
TEST_F(ShapeRefinerTest, ConstantValueAsShape_PackInt64) { TEST_F(ShapeRefinerTest, ConstantValueAsShape_PackInt64) {
Scope root = Scope::NewRootScope(); Scope root = Scope::DisabledShapeInferenceScope();
Node* scalar_non_const; Node* scalar_non_const;
TF_ASSERT_OK(NodeBuilder("in", "NonConstScalarInt64") TF_ASSERT_OK(NodeBuilder("in", "NonConstScalarInt64")
.Finalize(root.graph(), &scalar_non_const)); .Finalize(root.graph(), &scalar_non_const));
@ -997,7 +997,7 @@ TEST_F(ShapeRefinerTest, ConstantValueAsShape_PackInvalidInput) {
} }
TEST_F(ShapeRefinerTest, ConstantValueAsShape_Concat) { TEST_F(ShapeRefinerTest, ConstantValueAsShape_Concat) {
Scope root = Scope::NewRootScope(); Scope root = Scope::DisabledShapeInferenceScope();
Graph* g = root.graph(); Graph* g = root.graph();
Node* partial_1; Node* partial_1;
Node* partial_2; Node* partial_2;
@ -1034,7 +1034,7 @@ TEST_F(ShapeRefinerTest, ConstantValueAsShape_Concat) {
} }
TEST_F(ShapeRefinerTest, ConstantValueAsShape_ConcatWithUnknown) { TEST_F(ShapeRefinerTest, ConstantValueAsShape_ConcatWithUnknown) {
Scope root = Scope::NewRootScope(); Scope root = Scope::DisabledShapeInferenceScope();
Graph* g = root.graph(); Graph* g = root.graph();
Node* scalar_non_const; Node* scalar_non_const;
TF_ASSERT_OK(NodeBuilder("in", "NonConstScalarInt32") TF_ASSERT_OK(NodeBuilder("in", "NonConstScalarInt32")
@ -1077,7 +1077,7 @@ TEST_F(ShapeRefinerTest, ConstantValueAsShape_ConcatWithUnknown) {
} }
TEST_F(ShapeRefinerTest, ConstantValueAsShape_ConcatInvalidDimValue) { TEST_F(ShapeRefinerTest, ConstantValueAsShape_ConcatInvalidDimValue) {
Scope root = Scope::NewRootScope(); Scope root = Scope::DisabledShapeInferenceScope();
Graph* g = root.graph(); Graph* g = root.graph();
Node* scalar_non_const; Node* scalar_non_const;
TF_ASSERT_OK(NodeBuilder("in", "NonConstScalarInt32") TF_ASSERT_OK(NodeBuilder("in", "NonConstScalarInt32")

View File

@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/cc/ops/math_ops.h" #include "tensorflow/cc/ops/math_ops.h"
#include "tensorflow/cc/ops/random_ops.h" #include "tensorflow/cc/ops/random_ops.h"
#include "tensorflow/cc/ops/sendrecv_ops.h" #include "tensorflow/cc/ops/sendrecv_ops.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/framework/versions.pb.h"
@ -141,9 +142,17 @@ void CheckLoopConstruction(const GraphDef& graph_def) {
} }
} }
REGISTER_OP("FloatInput").Output("o: float"); REGISTER_OP("FloatInput")
REGISTER_OP("BoolInput").Output("o: bool"); .Output("o: float")
REGISTER_OP("Combine").Input("a: float").Input("b: float").Output("o: float"); .SetShapeFn(shape_inference::UnknownShape);
REGISTER_OP("BoolInput")
.Output("o: bool")
.SetShapeFn(shape_inference::UnknownShape);
REGISTER_OP("Combine")
.Input("a: float")
.Input("b: float")
.Output("o: float")
.SetShapeFn(shape_inference::UnknownShape);
Output ConstructOp(const Scope& scope, const string& op_type, Output ConstructOp(const Scope& scope, const string& op_type,
const gtl::ArraySlice<Input>& inputs) { const gtl::ArraySlice<Input>& inputs) {
@ -158,6 +167,8 @@ Output ConstructOp(const Scope& scope, const string& op_type,
Node* ret; Node* ret;
scope.UpdateStatus(builder.Finalize(scope.graph(), &ret)); scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
if (!scope.ok()) return Output(); if (!scope.ok()) return Output();
scope.UpdateStatus(scope.DoShapeInference(ret));
if (!scope.ok()) return Output();
return Output(ret); return Output(ret);
} }

View File

@ -28,7 +28,7 @@ namespace {
class AutoParallelTest : public ::testing::Test {}; class AutoParallelTest : public ::testing::Test {};
TEST_F(AutoParallelTest, SimpleParallel) { TEST_F(AutoParallelTest, SimpleParallel) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope(); tensorflow::Scope s = tensorflow::Scope::DisabledShapeInferenceScope();
Output constant_a = ops::Const(s.WithOpName("constant_a"), 1.0f, {1}); Output constant_a = ops::Const(s.WithOpName("constant_a"), 1.0f, {1});
Output constant_b = ops::Const(s.WithOpName("constant_b"), 1, {1}); Output constant_b = ops::Const(s.WithOpName("constant_b"), 1, {1});
Output var = ops::Variable(s.WithOpName("var"), {1}, DT_FLOAT); Output var = ops::Variable(s.WithOpName("var"), {1}, DT_FLOAT);

View File

@ -35,7 +35,7 @@ namespace tensorflow {
using namespace ops; // NOLINT(build/namespaces) using namespace ops; // NOLINT(build/namespaces)
TEST(EncodeWavOpTest, EncodeWavTest) { TEST(EncodeWavOpTest, EncodeWavTest) {
Scope root = Scope::NewRootScope(); Scope root = Scope::DisabledShapeInferenceScope();
Tensor audio_tensor(DT_FLOAT, {4, 2}); Tensor audio_tensor(DT_FLOAT, {4, 2});
test::FillValues<float>( test::FillValues<float>(

View File

@ -88,7 +88,7 @@ class FuzzSession {
} }
initialized_ = true; initialized_ = true;
Scope root = Scope::NewRootScope().ExitOnError(); Scope root = Scope::DisabledShapeInferenceScope().ExitOnError();
SessionOptions options; SessionOptions options;
session_ = std::unique_ptr<Session>(NewSession(options)); session_ = std::unique_ptr<Session>(NewSession(options));

View File

@ -121,7 +121,7 @@ TEST(ImmutableConstantOpTest, ExecutionError) {
const TensorShape kBadTensorShape({40, 100}); const TensorShape kBadTensorShape({40, 100});
const TensorShape kTestTensorShapeT({1, 4}); const TensorShape kTestTensorShapeT({1, 4});
auto root = Scope::NewRootScope().ExitOnError(); auto root = Scope::DisabledShapeInferenceScope().ExitOnError();
auto node1 = auto node1 =
ops::ImmutableConst(root, DT_FLOAT, kBadTensorShape, "test:///2"); ops::ImmutableConst(root, DT_FLOAT, kBadTensorShape, "test:///2");
auto node2 = auto node2 =

View File

@ -35,7 +35,7 @@ namespace tensorflow {
using namespace ops; // NOLINT(build/namespaces) using namespace ops; // NOLINT(build/namespaces)
TEST(MfccOpTest, SimpleTest) { TEST(MfccOpTest, SimpleTest) {
Scope root = Scope::NewRootScope(); Scope root = Scope::DisabledShapeInferenceScope();
Tensor spectrogram_tensor(DT_FLOAT, TensorShape({1, 1, 513})); Tensor spectrogram_tensor(DT_FLOAT, TensorShape({1, 1, 513}));
test::FillIota<float>(&spectrogram_tensor, 1.0f); test::FillIota<float>(&spectrogram_tensor, 1.0f);

View File

@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op.h"
namespace tensorflow { namespace tensorflow {
@ -26,6 +27,7 @@ REGISTER_OP("_Send")
.Attr("recv_device: string") .Attr("recv_device: string")
.Attr("client_terminated: bool = false") .Attr("client_terminated: bool = false")
.SetIsStateful() .SetIsStateful()
.SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc( .Doc(R"doc(
Sends the named tensor from send_device to recv_device. Sends the named tensor from send_device to recv_device.
@ -49,6 +51,7 @@ REGISTER_OP("_Recv")
.Attr("recv_device: string") .Attr("recv_device: string")
.Attr("client_terminated: bool = false") .Attr("client_terminated: bool = false")
.SetIsStateful() .SetIsStateful()
.SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc( .Doc(R"doc(
Receives the named tensor from send_device on recv_device. Receives the named tensor from send_device on recv_device.
@ -72,6 +75,7 @@ REGISTER_OP("_HostSend")
.Attr("recv_device: string") .Attr("recv_device: string")
.Attr("client_terminated: bool = false") .Attr("client_terminated: bool = false")
.SetIsStateful() .SetIsStateful()
.SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc( .Doc(R"doc(
Sends the named tensor from send_device to recv_device. Sends the named tensor from send_device to recv_device.
@ -98,6 +102,7 @@ REGISTER_OP("_HostRecv")
.Attr("recv_device: string") .Attr("recv_device: string")
.Attr("client_terminated: bool = false") .Attr("client_terminated: bool = false")
.SetIsStateful() .SetIsStateful()
.SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc( .Doc(R"doc(
Receives the named tensor from send_device on recv_device. Receives the named tensor from send_device on recv_device.

View File

@ -36,7 +36,7 @@ class FakeQuantizeTrainingTest : public ::testing::Test {};
// TODO(suharshs): Once we implement the fake_quantize_training transform // TODO(suharshs): Once we implement the fake_quantize_training transform
// using the GTT, write proper tests of the transform here. // using the GTT, write proper tests of the transform here.
TEST_F(FakeQuantizeTrainingTest, TransformOccurred) { TEST_F(FakeQuantizeTrainingTest, TransformOccurred) {
auto root = tensorflow::Scope::NewRootScope(); auto root = tensorflow::Scope::DisabledShapeInferenceScope();
using namespace ::tensorflow::ops; // NOLINT(build/namespaces) using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
Tensor a_data(DT_FLOAT, TensorShape()); Tensor a_data(DT_FLOAT, TensorShape());

View File

@ -40,7 +40,7 @@ class QuantizeWeightsTest : public ::testing::Test {
const TensorShape& weight_shape, const TensorShape& weight_shape,
std::initializer_list<float> weight_values, std::initializer_list<float> weight_values,
GraphDef* original_graph_def) { GraphDef* original_graph_def) {
auto root = tensorflow::Scope::NewRootScope(); auto root = tensorflow::Scope::DisabledShapeInferenceScope();
Tensor input_data(DT_FLOAT, input_shape); Tensor input_data(DT_FLOAT, input_shape);
test::FillValues<float>(&input_data, input_values); test::FillValues<float>(&input_data, input_values);

View File

@ -622,7 +622,7 @@ class TransformUtilsTest : public ::testing::Test {
} }
void TestRenameNodeInputsWithWildcard() { void TestRenameNodeInputsWithWildcard() {
auto root = tensorflow::Scope::NewRootScope(); auto root = tensorflow::Scope::DisabledShapeInferenceScope();
using namespace ::tensorflow::ops; // NOLINT(build/namespaces) using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
const int width = 10; const int width = 10;