C++ API: run shape inference as nodes are constructed

Here's an example of the new generated code:

AddN::AddN(const ::tensorflow::Scope& scope, ::tensorflow::InputList inputs) {
  if (!scope.ok()) return;
  auto _inputs = ::tensorflow::ops::AsNodeOutList(scope, inputs);
  if (!scope.ok()) return;
  ::tensorflow::Node* ret;
  const auto unique_name = scope.GetUniqueNameForOp("AddN");
  auto builder = ::tensorflow::NodeBuilder(unique_name, "AddN")
                     .Input(_inputs)
  ;
  scope.UpdateBuilder(&builder);
  scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
  if (!scope.ok()) return;
  scope.UpdateStatus(scope.DoShapeInference(ret));
  this->sum = Output(ret, 0);
}
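
A minimal caller-side sketch (mine, not from the commit; the op and values are only illustrative) of what this changes: the wrapper constructor now runs the op's shape function, so a shape error lands on the scope's status as soon as the op is constructed, instead of surfacing when the graph is later executed.

#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/standard_ops.h"

int main() {
  using namespace tensorflow;  // NOLINT(build/namespaces)
  Scope scope = Scope::NewRootScope();
  auto a = ops::Const(scope, {{1.f, 2.f}});     // shape [1, 2]
  auto b = ops::Const(scope, {1.f, 2.f, 3.f});  // shape [3]
  // MatMul's shape function requires two rank-2 inputs, so inference fails
  // inside the MatMul constructor and poisons the scope's status.
  auto bad = ops::MatMul(scope, a, b);
  if (!scope.ok()) {
    // scope.status() now carries the shape inference error.
  }
  return 0;
}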

Enabling shape inference unfortunately broke many tests. I fixed some of them, but for others I introduced a Scope::DisabledShapeInferenceScope() static method that returns a scope that doesn't perform shape inference. Eventually we should fix the tests that use this and remove it.
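
For reference, a hedged sketch of the escape hatch in a test (mirroring the test changes below; _Arg is one of the affected ops): DoShapeInference() returns OK unconditionally on such a scope and on any scope derived from it.

Scope s = Scope::DisabledShapeInferenceScope();
// _Arg has no usable shape function here, but construction still succeeds
// because DoShapeInference() is a no-op on this scope and its children.
auto x = ops::_Arg(s.WithOpName("x"), DT_INT32, 0);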

PiperOrigin-RevId: 165378429
Author: Skye Wanderman-Milne (2017-08-15 16:46:16 -07:00), committed by TensorFlower Gardener
parent 9ba0abc2f0
commit 477d49c9ea
26 changed files with 143 additions and 50 deletions


@@ -812,12 +812,8 @@ string OpInfo::GetConstructorBody() const {
strings::StrAppend(&body, " ", scope_str, ".UpdateStatus(builder.Finalize(",
scope_str, ".graph(), &ret));\n");
strings::StrAppend(&body, " ", return_on_error, "\n");
-// TODO(b/28152992): Enable this code-path once we have converted
-// all python shape functions to call their C++ versions.
-// strings::StrAppend(&body, " ", scope_str, ".UpdateStatus(", scope_str,
-// ".refiner()->AddNode(ret));\n");
+strings::StrAppend(&body, " ", scope_str, ".UpdateStatus(", scope_str,
+".DoShapeInference(ret));\n");
GetOutput(&body);
return body;


@@ -37,13 +37,14 @@ Scope& Scope::operator=(const Scope& other) {
}
Scope::Impl::Impl(Graph* graph, Status* status, NameMap* name_map,
-ShapeRefiner* refiner)
+ShapeRefiner* refiner, bool disable_shape_inference)
: graph_(graph),
status_(status),
name_map_(name_map),
refiner_(refiner),
scope_used_(nullptr),
-colocation_constraints_() {}
+colocation_constraints_(),
+disable_shape_inference_(disable_shape_inference) {}
Scope::Impl::Impl(const std::shared_ptr<Graph>& graph,
const std::shared_ptr<Status>& status,
@@ -54,13 +55,23 @@ Scope::Impl::Impl(const std::shared_ptr<Graph>& graph,
name_map_(name_map),
refiner_(refiner),
scope_used_(nullptr),
-colocation_constraints_() {}
+colocation_constraints_(),
+disable_shape_inference_(false) {}
Scope Scope::NewRootScope() {
Graph* graph = new Graph(OpRegistry::Global());
ShapeRefiner* refiner =
new ShapeRefiner(graph->versions(), graph->op_registry());
-return Scope(new Impl(graph, new Status, new Impl::NameMap, refiner));
+return Scope(new Impl(graph, new Status, new Impl::NameMap, refiner,
+/* disable_shape_inference */ false));
}
+Scope Scope::DisabledShapeInferenceScope() {
+Graph* graph = new Graph(OpRegistry::Global());
+ShapeRefiner* refiner =
+new ShapeRefiner(graph->versions(), graph->op_registry());
+return Scope(new Impl(graph, new Status, new Impl::NameMap, refiner,
+/* disable_shape_inference */ true));
+}
Scope::Impl::Impl(const Scope& other, Tags::ScopeName, const string& name,
@@ -77,7 +88,8 @@ Scope::Impl::Impl(const Scope& other, Tags::ScopeName, const string& name,
exit_on_error_(other.impl()->exit_on_error_),
kernel_label_(other.impl()->kernel_label_),
device_(other.impl()->device_),
-colocation_constraints_(other.impl()->colocation_constraints_) {}
+colocation_constraints_(other.impl()->colocation_constraints_),
+disable_shape_inference_(other.impl()->disable_shape_inference_) {}
Scope::Impl::Impl(const Scope& other, Tags::OpName, const string& name,
const string& op_name)
@@ -92,7 +104,8 @@ Scope::Impl::Impl(const Scope& other, Tags::OpName, const string& name,
exit_on_error_(other.impl()->exit_on_error_),
kernel_label_(other.impl()->kernel_label_),
device_(other.impl()->device_),
-colocation_constraints_(other.impl()->colocation_constraints_) {}
+colocation_constraints_(other.impl()->colocation_constraints_),
+disable_shape_inference_(other.impl()->disable_shape_inference_) {}
Scope::Impl::Impl(const Scope& other, Tags::ControlDeps,
std::vector<Operation> control_deps, bool clear_control_deps)
@@ -113,7 +126,8 @@ Scope::Impl::Impl(const Scope& other, Tags::ControlDeps,
exit_on_error_(other.impl()->exit_on_error_),
kernel_label_(other.impl()->kernel_label_),
device_(other.impl()->device_),
-colocation_constraints_(other.impl()->colocation_constraints_) {}
+colocation_constraints_(other.impl()->colocation_constraints_),
+disable_shape_inference_(other.impl()->disable_shape_inference_) {}
Scope::Impl::Impl(const Scope& other, Tags::Device, const string& device)
: graph_(other.impl()->graph_),
@@ -127,7 +141,8 @@ Scope::Impl::Impl(const Scope& other, Tags::Device, const string& device)
exit_on_error_(other.impl()->exit_on_error_),
kernel_label_(other.impl()->kernel_label_),
device_(device),
-colocation_constraints_(other.impl()->colocation_constraints_) {}
+colocation_constraints_(other.impl()->colocation_constraints_),
+disable_shape_inference_(other.impl()->disable_shape_inference_) {}
Scope::Impl::Impl(const Scope& other, Tags::SingleUseScope,
const string& op_name)
@@ -142,7 +157,8 @@ Scope::Impl::Impl(const Scope& other, Tags::SingleUseScope,
exit_on_error_(other.impl()->exit_on_error_),
kernel_label_(other.impl()->kernel_label_),
device_(other.impl()->device_),
-colocation_constraints_(other.impl()->colocation_constraints_) {}
+colocation_constraints_(other.impl()->colocation_constraints_),
+disable_shape_inference_(other.impl()->disable_shape_inference_) {}
Scope::Impl::Impl(const Scope& other, Tags::ExitOnError)
: graph_(other.impl()->graph_),
@@ -156,7 +172,8 @@ Scope::Impl::Impl(const Scope& other, Tags::ExitOnError)
exit_on_error_(true),
kernel_label_(other.impl()->kernel_label_),
device_(other.impl()->device_),
-colocation_constraints_(other.impl()->colocation_constraints_) {}
+colocation_constraints_(other.impl()->colocation_constraints_),
+disable_shape_inference_(other.impl()->disable_shape_inference_) {}
Scope::Impl::Impl(const Scope& other, Tags::KernelLabel,
const string& kernel_label)
@@ -171,7 +188,8 @@ Scope::Impl::Impl(const Scope& other, Tags::KernelLabel,
exit_on_error_(other.impl()->exit_on_error_),
kernel_label_(kernel_label),
device_(other.impl()->device_),
-colocation_constraints_(other.impl()->colocation_constraints_) {}
+colocation_constraints_(other.impl()->colocation_constraints_),
+disable_shape_inference_(other.impl()->disable_shape_inference_) {}
Scope::Impl::Impl(const Scope& other, Tags::Colocate,
const Operation& colocate_with_op, bool clear_colocations)
@@ -189,7 +207,8 @@ Scope::Impl::Impl(const Scope& other, Tags::Colocate,
colocation_constraints_(
clear_colocations
? std::unordered_set<string>()
-: other.impl()->GetColocationConstraints(colocate_with_op)) {}
+: other.impl()->GetColocationConstraints(colocate_with_op)),
+disable_shape_inference_(other.impl()->disable_shape_inference_) {}
std::unordered_set<string> Scope::Impl::GetColocationConstraints(
const Operation& colocate_with_op) const {
@@ -404,6 +423,11 @@ CompositeOpScopes Scope::GetCompositeOpScopes(
}
}
+Status Scope::DoShapeInference(Node* node) const {
+if (impl_->disable_shape_inference_) return Status::OK();
+return impl_->refiner_->AddNode(node);
+}
class InternalScope {
public:
// NewScope doesn't take ownership of the inputs.


@@ -199,6 +199,18 @@ class Scope {
// edges from the source and to the sink node, resolves back edges
// by name), and makes sure the resulting graph is valid.
Status ToGraph(Graph* g) const;
+// Calls AddNode() using this scope's ShapeRefiner. This exists in the public
+// API to prevent custom op wrappers from needing access to shape_refiner.h or
+// scope_internal.h.
+// TODO(skyewm): remove this from public API
+Status DoShapeInference(Node* node) const;
+// Creates a new root scope that causes all DoShapeInference() calls to return
+// Status::OK() (on the returned scope and any subscopes). Used for testing.
+// TODO(skyewm): fix tests that still require this and eventually remove, or
+// at least remove from public API
+static Scope DisabledShapeInferenceScope();
// END_SKIP_DOXYGEN
const std::vector<Operation>& control_deps() const;


@@ -58,7 +58,8 @@ class Scope::Impl {
enum class Colocate;
};
-Impl(Graph* graph, Status* status, NameMap* name_map, ShapeRefiner* refiner);
+Impl(Graph* graph, Status* status, NameMap* name_map, ShapeRefiner* refiner,
+bool disable_shape_inference);
Impl(const Scope& other, Tags::ScopeName, const string& name,
bool copy_names);
Impl(const Scope& other, Tags::OpName, const string& name,
@@ -103,6 +104,10 @@ class Scope::Impl {
const string kernel_label_ = "";
const string device_ = "";
const std::unordered_set<string> colocation_constraints_;
+// If true, Scope::DoShapeInference() always returns Status::OK().
+// TODO(skyewm): remove this when possible
+const bool disable_shape_inference_;
};
} // namespace tensorflow


@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
namespace tensorflow {
@@ -24,6 +25,7 @@ REGISTER_OP("ThrowAway1")
.Attr("scope: int")
.Attr("builder: int = 1")
.Attr("while: int")
+.SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc(
Op to test keywords and reserved words in input and attr names.
@@ -36,12 +38,20 @@ REGISTER_OP("ThrowAway2")
.Attr("scope: int = 2")
.Attr("throw_away2: int = 2")
.Attr("attrs: int = 4")
-.Attr("node: int = 4");
+.Attr("node: int = 4")
+.SetShapeFn(shape_inference::UnknownShape);
-REGISTER_OP("ThrowAway3").Output("node: int32");
+REGISTER_OP("ThrowAway3")
+.Output("node: int32")
+.SetShapeFn(shape_inference::UnknownShape);
-REGISTER_OP("ThrowAway4").Input("node: int32");
+REGISTER_OP("ThrowAway4")
+.Input("node: int32")
+.SetShapeFn(shape_inference::UnknownShape);
-REGISTER_OP("ThrowAway5").Output("foo: int32").Attr("node: int = 4");
+REGISTER_OP("ThrowAway5")
+.Output("foo: int32")
+.Attr("node: int = 4")
+.SetShapeFn(shape_inference::UnknownShape);
} // namespace tensorflow


@@ -34,7 +34,9 @@ Output Const(const Scope& scope, const Input::Initializer& val) {
.Attr("dtype", val.tensor.dtype());
scope.UpdateBuilder(&builder);
scope.UpdateStatus(builder.Finalize(graph, &ret));
if (!scope.ok()) return Output();
+scope.UpdateStatus(scope.DoShapeInference(ret));
+if (!scope.ok()) return Output();
return Output(ret);


@@ -56,6 +56,8 @@ Output Const(const Scope& scope, const Input::Initializer& val) {
scope.UpdateBuilder(&cast_builder);
Node* ret;
scope.UpdateStatus(cast_builder.Finalize(scope.graph(), &ret));
+if (!scope.ok()) return Output();
+scope.UpdateStatus(scope.DoShapeInference(ret));
return Output(ret, 0);
}


@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
namespace tensorflow {
@@ -26,6 +27,7 @@ REGISTER_OP("XlaWhile")
.Attr("cond: func")
.Attr("body: func")
.SetIsStateful()
+.SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc(
output = input; While (Cond(output)) { output = Body(output) }


@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
namespace tensorflow {
@@ -22,6 +23,7 @@ REGISTER_OP("_XLASend")
.Attr("T: type")
.Attr("tensor_name: string")
.SetIsStateful()
+.SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc(
Sends the named tensor to another XLA computation.
@@ -35,6 +37,7 @@ REGISTER_OP("_XLARecv")
.Attr("tensor_name: string")
.Attr("shape: shape")
.SetIsStateful()
+.SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc(
Receives the named tensor from another XLA computation.


@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/graph/graph.h"
@@ -76,6 +77,8 @@ class DummyReadResourceCC {
scope.UpdateBuilder(&builder);
scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
if (!scope.ok()) return;
+scope.UpdateStatus(scope.DoShapeInference(ret));
+if (!scope.ok()) return;
this->output_ = Output(ret, 0);
}
Node* node() const { return output_.node(); }
@@ -86,6 +89,7 @@ class DummyReadResourceCC {
REGISTER_OP("DummyReadResource")
.Input("input: int32")
.Output("output: int32")
+.SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc(
A dummy Op.


@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
@@ -67,6 +68,7 @@ REGISTER_OP("TPUReplicate")
.Input("broadcast_inputs: Tbroadcast_inputs")
.Input("variables: NumVariables * resource")
.Output("outputs: output_types")
+.SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc(
Runs replicated computations on a distributed TPU system.


@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/lib/core/status.h"
@@ -154,7 +155,10 @@ global_tpu_array: A two-dimensional array. For each host (the outer
dimension) the array lists the global ids of the TPUs on that host.
)doc");
-REGISTER_OP("_ShutdownDistributedTPU").SetIsStateful().Doc(R"doc(
+REGISTER_OP("_ShutdownDistributedTPU")
+.SetIsStateful()
+.SetShapeFn(shape_inference::UnknownShape)
+.Doc(R"doc(
An op that shuts down a running distributed TPU system. The Op returns
an error if no system is running. This Op must be run on the same
TPU_SYSTEM device as the corresponding _ConfigureDistributedTPU was run
@@ -184,6 +188,7 @@ tpu_ids: A vector containing the global TPU id of each TPU on the host.
REGISTER_OP("_DisconnectHostFromDistributedTPUSystem")
.Output("number_of_tpu_chips: int32")
.SetIsStateful()
+.SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc(
An op that disconnects the TPUs on a host from a running distributed
TPU system.
@@ -196,6 +201,7 @@ REGISTER_OP("ConfigureDistributedTPU")
.Output("global_tpu_array: int32")
.Attr("embedding_config: string = ''")
.SetIsStateful()
+.SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc(
An op that sets up the centralized structures for a distributed TPU
system.
@@ -205,7 +211,10 @@ dimension) the array lists the global ids of the TPUs on that host.
embedding_config: Internal use.
)doc");
-REGISTER_OP("ShutdownDistributedTPU").SetIsStateful().Doc(R"doc(
+REGISTER_OP("ShutdownDistributedTPU")
+.SetIsStateful()
+.SetShapeFn(shape_inference::UnknownShape)
+.Doc(R"doc(
An op that shuts down a running distributed TPU system. The Op returns
an error if no system is running.
)doc");


@@ -282,6 +282,7 @@ TEST_F(ConstantFoldingTest, TestNoReplaceFunctionCall) {
Status status;
Node* times_two = s.graph()->AddNode(def, &status);
TF_ASSERT_OK(status);
+TF_ASSERT_OK(s.DoShapeInference(times_two));
s.graph()->AddEdge(c.node(), 0, times_two, 0);
auto times_two_send =
@@ -297,7 +298,10 @@ TEST_F(ConstantFoldingTest, TestNoReplaceFunctionCall) {
EXPECT_FALSE(was_mutated);
}
-REGISTER_OP("ConstantFoldingTestOp").Input("a: int64").Output("b: int64");
+REGISTER_OP("ConstantFoldingTestOp")
+.Input("a: int64")
+.Output("b: int64")
+.SetShapeFn(shape_inference::UnknownShape);
TEST_F(ConstantFoldingTest, TestNoReplaceNonCPUOp) {
Graph g(OpRegistry::Global());
@@ -312,6 +316,7 @@ TEST_F(ConstantFoldingTest, TestNoReplaceNonCPUOp) {
Status status;
Node* non_cpu = s.graph()->AddNode(def, &status);
TF_ASSERT_OK(status);
+TF_ASSERT_OK(s.DoShapeInference(non_cpu));
auto non_cpu_send =
ops::_Send(s.WithOpName("non_cpu_send"), Output(non_cpu),


@@ -284,6 +284,7 @@ Output Call(Scope* scope, const string& op_name, const string& fn_name,
Status status;
Node* n = scope->graph()->AddNode(def, &status);
TF_CHECK_OK(status);
+TF_CHECK_OK(scope->DoShapeInference(n));
for (int i = 0; i < inputs.size(); ++i) {
scope->graph()->AddEdge(inputs[i].node(), inputs[i].index(), n, i);
}
@@ -989,7 +990,7 @@ TEST(OptimizationTest, RemoveDeadNodes) {
GraphDef expected;
{
-Scope s = Scope::NewRootScope();
+Scope s = Scope::DisabledShapeInferenceScope();
auto x = ops::_Arg(s.WithOpName("x"), DT_INT32, 0);
auto o = ops::Const(s.WithOpName("o"), 1);
auto keep_me = ops::RandomUniform(s.WithOpName("keep_me"), {o}, DT_FLOAT);
@@ -1070,7 +1071,7 @@ TEST(OptimizationTest, RemoveIdentityNodes) {
{{"y"}, "Add", {"a", "o"}, {{"T", T}}}});
{
-Scope s = Scope::NewRootScope();
+Scope s = Scope::DisabledShapeInferenceScope();
auto x = ops::_Arg(s.WithOpName("x"), DT_INT32, 0);
auto o = ops::Const(s.WithOpName("o"), 1);
auto a = ops::Square(s.WithOpName("a"), x);
@@ -1087,7 +1088,7 @@ TEST(OptimizationTest, RemoveIdentityNodes) {
}
{
-Scope s = Scope::NewRootScope();
+Scope s = Scope::DisabledShapeInferenceScope();
auto x = ops::_Arg(s.WithOpName("x"), DT_INT32, 0);
auto o = ops::Const(s.WithOpName("o"), 1);
auto a = ops::Square(s.WithOpName("a"), x);
@@ -1137,7 +1138,7 @@ TEST(OptimizationTest, RemoveListArrayConverter) {
{{"o", "o:sum"}});
{
-Scope scope = Scope::NewRootScope();
+Scope scope = Scope::DisabledShapeInferenceScope();
auto i = ops::_Arg(scope.WithOpName("i"), DT_FLOAT, 0);
auto zero = ops::Const(scope.WithOpName("zero"), 0);
auto s = ops::Split(scope.WithOpName("s"), zero, i, 4);
@@ -1222,7 +1223,7 @@ TEST(OptimizationTest, RemoveListArrayConverter_WithContolDeps) {
{{"o", "o:sum"}});
{
-Scope s = Scope::NewRootScope();
+Scope s = Scope::DisabledShapeInferenceScope();
auto i = ops::_Arg(s.WithOpName("i"), DT_FLOAT, 0);
auto dummy = ops::Const(s.WithOpName("dummy"), 0);
auto x = ops::_ListToArray(s.WithOpName("x").WithControlDependencies(dummy),


@@ -64,7 +64,7 @@ TEST_F(GpuStreamUtilTest, EmptyGraph) {
}
TEST_F(GpuStreamUtilTest, SimpleGraphOneStream) {
-auto root = Scope::NewRootScope().ExitOnError();
+auto root = Scope::DisabledShapeInferenceScope().ExitOnError();
ops::MatMul(root, {}, {});
Graph g(OpRegistry::Global());
TF_ASSERT_OK(root.ToGraph(&g));
@@ -83,7 +83,7 @@ TEST_F(GpuStreamUtilTest, SimpleGraphOneStream) {
}
TEST_F(GpuStreamUtilTest, SimpleGraphManyStreams) {
-auto root = Scope::NewRootScope().ExitOnError();
+auto root = Scope::DisabledShapeInferenceScope().ExitOnError();
ops::MatMul(root, {}, {});
Graph g(OpRegistry::Global());
TF_ASSERT_OK(root.ToGraph(&g));
@@ -104,7 +104,7 @@ TEST_F(GpuStreamUtilTest, SimpleGraphManyStreams) {
}
TEST_F(GpuStreamUtilTest, StreamOverrides) {
-auto root = Scope::NewRootScope().ExitOnError();
+auto root = Scope::DisabledShapeInferenceScope().ExitOnError();
ops::_Recv(root.WithOpName("input"), DT_FLOAT, "input", "/cpu:0", 0,
"/device:GPU:0");
Output n = ops::MatMul(root, {}, {});


@@ -882,7 +882,7 @@ TEST_F(ShapeRefinerTest, ConstantValueAsShape_Shape) {
}
TEST_F(ShapeRefinerTest, ConstantValueAsShape_PackInt32) {
-Scope root = Scope::NewRootScope();
+Scope root = Scope::DisabledShapeInferenceScope();
Node* scalar_non_const;
TF_ASSERT_OK(NodeBuilder("in", "NonConstScalarInt32")
.Finalize(root.graph(), &scalar_non_const));
@@ -914,7 +914,7 @@ TEST_F(ShapeRefinerTest, ConstantValueAsShape_PackInt32) {
}
TEST_F(ShapeRefinerTest, ConstantValueAsShape_PackInt64) {
-Scope root = Scope::NewRootScope();
+Scope root = Scope::DisabledShapeInferenceScope();
Node* scalar_non_const;
TF_ASSERT_OK(NodeBuilder("in", "NonConstScalarInt64")
.Finalize(root.graph(), &scalar_non_const));
@@ -997,7 +997,7 @@ TEST_F(ShapeRefinerTest, ConstantValueAsShape_PackInvalidInput) {
}
TEST_F(ShapeRefinerTest, ConstantValueAsShape_Concat) {
-Scope root = Scope::NewRootScope();
+Scope root = Scope::DisabledShapeInferenceScope();
Graph* g = root.graph();
Node* partial_1;
Node* partial_2;
@@ -1034,7 +1034,7 @@ TEST_F(ShapeRefinerTest, ConstantValueAsShape_Concat) {
}
TEST_F(ShapeRefinerTest, ConstantValueAsShape_ConcatWithUnknown) {
-Scope root = Scope::NewRootScope();
+Scope root = Scope::DisabledShapeInferenceScope();
Graph* g = root.graph();
Node* scalar_non_const;
TF_ASSERT_OK(NodeBuilder("in", "NonConstScalarInt32")
@@ -1077,7 +1077,7 @@ TEST_F(ShapeRefinerTest, ConstantValueAsShape_ConcatWithUnknown) {
}
TEST_F(ShapeRefinerTest, ConstantValueAsShape_ConcatInvalidDimValue) {
-Scope root = Scope::NewRootScope();
+Scope root = Scope::DisabledShapeInferenceScope();
Graph* g = root.graph();
Node* scalar_non_const;
TF_ASSERT_OK(NodeBuilder("in", "NonConstScalarInt32")


@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/cc/ops/math_ops.h"
#include "tensorflow/cc/ops/random_ops.h"
#include "tensorflow/cc/ops/sendrecv_ops.h"
+#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/versions.pb.h"
@@ -141,9 +142,17 @@ void CheckLoopConstruction(const GraphDef& graph_def) {
}
}
-REGISTER_OP("FloatInput").Output("o: float");
-REGISTER_OP("BoolInput").Output("o: bool");
-REGISTER_OP("Combine").Input("a: float").Input("b: float").Output("o: float");
+REGISTER_OP("FloatInput")
+.Output("o: float")
+.SetShapeFn(shape_inference::UnknownShape);
+REGISTER_OP("BoolInput")
+.Output("o: bool")
+.SetShapeFn(shape_inference::UnknownShape);
+REGISTER_OP("Combine")
+.Input("a: float")
+.Input("b: float")
+.Output("o: float")
+.SetShapeFn(shape_inference::UnknownShape);
Output ConstructOp(const Scope& scope, const string& op_type,
const gtl::ArraySlice<Input>& inputs) {
@@ -158,6 +167,8 @@ Output ConstructOp(const Scope& scope, const string& op_type,
Node* ret;
scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
if (!scope.ok()) return Output();
+scope.UpdateStatus(scope.DoShapeInference(ret));
+if (!scope.ok()) return Output();
return Output(ret);
}


@@ -28,7 +28,7 @@ namespace {
class AutoParallelTest : public ::testing::Test {};
TEST_F(AutoParallelTest, SimpleParallel) {
-tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+tensorflow::Scope s = tensorflow::Scope::DisabledShapeInferenceScope();
Output constant_a = ops::Const(s.WithOpName("constant_a"), 1.0f, {1});
Output constant_b = ops::Const(s.WithOpName("constant_b"), 1, {1});
Output var = ops::Variable(s.WithOpName("var"), {1}, DT_FLOAT);


@@ -35,7 +35,7 @@ namespace tensorflow {
using namespace ops; // NOLINT(build/namespaces)
TEST(EncodeWavOpTest, EncodeWavTest) {
-Scope root = Scope::NewRootScope();
+Scope root = Scope::DisabledShapeInferenceScope();
Tensor audio_tensor(DT_FLOAT, {4, 2});
test::FillValues<float>(


@@ -88,7 +88,7 @@ class FuzzSession {
}
initialized_ = true;
-Scope root = Scope::NewRootScope().ExitOnError();
+Scope root = Scope::DisabledShapeInferenceScope().ExitOnError();
SessionOptions options;
session_ = std::unique_ptr<Session>(NewSession(options));


@@ -121,7 +121,7 @@ TEST(ImmutableConstantOpTest, ExecutionError) {
const TensorShape kBadTensorShape({40, 100});
const TensorShape kTestTensorShapeT({1, 4});
-auto root = Scope::NewRootScope().ExitOnError();
+auto root = Scope::DisabledShapeInferenceScope().ExitOnError();
auto node1 =
ops::ImmutableConst(root, DT_FLOAT, kBadTensorShape, "test:///2");
auto node2 =


@@ -35,7 +35,7 @@ namespace tensorflow {
using namespace ops; // NOLINT(build/namespaces)
TEST(MfccOpTest, SimpleTest) {
-Scope root = Scope::NewRootScope();
+Scope root = Scope::DisabledShapeInferenceScope();
Tensor spectrogram_tensor(DT_FLOAT, TensorShape({1, 1, 513}));
test::FillIota<float>(&spectrogram_tensor, 1.0f);


@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
namespace tensorflow {
@@ -26,6 +27,7 @@ REGISTER_OP("_Send")
.Attr("recv_device: string")
.Attr("client_terminated: bool = false")
.SetIsStateful()
+.SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc(
Sends the named tensor from send_device to recv_device.
@@ -49,6 +51,7 @@ REGISTER_OP("_Recv")
.Attr("recv_device: string")
.Attr("client_terminated: bool = false")
.SetIsStateful()
+.SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc(
Receives the named tensor from send_device on recv_device.
@@ -72,6 +75,7 @@ REGISTER_OP("_HostSend")
.Attr("recv_device: string")
.Attr("client_terminated: bool = false")
.SetIsStateful()
+.SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc(
Sends the named tensor from send_device to recv_device.
@@ -98,6 +102,7 @@ REGISTER_OP("_HostRecv")
.Attr("recv_device: string")
.Attr("client_terminated: bool = false")
.SetIsStateful()
+.SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc(
Receives the named tensor from send_device on recv_device.


@@ -36,7 +36,7 @@ class FakeQuantizeTrainingTest : public ::testing::Test {};
// TODO(suharshs): Once we implement the fake_quantize_training transform
// using the GTT, write proper tests of the transform here.
TEST_F(FakeQuantizeTrainingTest, TransformOccurred) {
-auto root = tensorflow::Scope::NewRootScope();
+auto root = tensorflow::Scope::DisabledShapeInferenceScope();
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
Tensor a_data(DT_FLOAT, TensorShape());


@@ -40,7 +40,7 @@ class QuantizeWeightsTest : public ::testing::Test {
const TensorShape& weight_shape,
std::initializer_list<float> weight_values,
GraphDef* original_graph_def) {
-auto root = tensorflow::Scope::NewRootScope();
+auto root = tensorflow::Scope::DisabledShapeInferenceScope();
Tensor input_data(DT_FLOAT, input_shape);
test::FillValues<float>(&input_data, input_values);


@@ -622,7 +622,7 @@ class TransformUtilsTest : public ::testing::Test {
}
void TestRenameNodeInputsWithWildcard() {
-auto root = tensorflow::Scope::NewRootScope();
+auto root = tensorflow::Scope::DisabledShapeInferenceScope();
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
const int width = 10;