C++ API: run shape inference as nodes are constructed
Here's an example of the new generated code: AddN::AddN(const ::tensorflow::Scope& scope, ::tensorflow::InputList inputs) { if (!scope.ok()) return; auto _inputs = ::tensorflow::ops::AsNodeOutList(scope, inputs); if (!scope.ok()) return; ::tensorflow::Node* ret; const auto unique_name = scope.GetUniqueNameForOp("AddN"); auto builder = ::tensorflow::NodeBuilder(unique_name, "AddN") .Input(_inputs) ; scope.UpdateBuilder(&builder); scope.UpdateStatus(builder.Finalize(scope.graph(), &ret)); if (!scope.ok()) return; scope.UpdateStatus(scope.DoShapeInference(ret)); this->sum = Output(ret, 0); } Enabling shape inference unfortunately broke many tests. I fixed some of them, but for others I introduced a Scope::DisabledShapeInferenceScope() static method that returns a scope that doesn't perform shape inference. Eventually we should fix the tests that use this and remove it. PiperOrigin-RevId: 165378429
This commit is contained in:
parent
9ba0abc2f0
commit
477d49c9ea
tensorflow
cc
compiler/tf2xla
contrib/tpu/ops
core
common_runtime
graph
grappler/optimizers
kernels
ops
tools/graph_transforms
@ -812,12 +812,8 @@ string OpInfo::GetConstructorBody() const {
|
||||
strings::StrAppend(&body, " ", scope_str, ".UpdateStatus(builder.Finalize(",
|
||||
scope_str, ".graph(), &ret));\n");
|
||||
strings::StrAppend(&body, " ", return_on_error, "\n");
|
||||
|
||||
// TODO(b/28152992): Enable this code-path once we have converted
|
||||
// all python shape functions to call their C++ versions.
|
||||
|
||||
// strings::StrAppend(&body, " ", scope_str, ".UpdateStatus(", scope_str,
|
||||
// ".refiner()->AddNode(ret));\n");
|
||||
strings::StrAppend(&body, " ", scope_str, ".UpdateStatus(", scope_str,
|
||||
".DoShapeInference(ret));\n");
|
||||
|
||||
GetOutput(&body);
|
||||
return body;
|
||||
|
@ -37,13 +37,14 @@ Scope& Scope::operator=(const Scope& other) {
|
||||
}
|
||||
|
||||
Scope::Impl::Impl(Graph* graph, Status* status, NameMap* name_map,
|
||||
ShapeRefiner* refiner)
|
||||
ShapeRefiner* refiner, bool disable_shape_inference)
|
||||
: graph_(graph),
|
||||
status_(status),
|
||||
name_map_(name_map),
|
||||
refiner_(refiner),
|
||||
scope_used_(nullptr),
|
||||
colocation_constraints_() {}
|
||||
colocation_constraints_(),
|
||||
disable_shape_inference_(disable_shape_inference) {}
|
||||
|
||||
Scope::Impl::Impl(const std::shared_ptr<Graph>& graph,
|
||||
const std::shared_ptr<Status>& status,
|
||||
@ -54,13 +55,23 @@ Scope::Impl::Impl(const std::shared_ptr<Graph>& graph,
|
||||
name_map_(name_map),
|
||||
refiner_(refiner),
|
||||
scope_used_(nullptr),
|
||||
colocation_constraints_() {}
|
||||
colocation_constraints_(),
|
||||
disable_shape_inference_(false) {}
|
||||
|
||||
Scope Scope::NewRootScope() {
|
||||
Graph* graph = new Graph(OpRegistry::Global());
|
||||
ShapeRefiner* refiner =
|
||||
new ShapeRefiner(graph->versions(), graph->op_registry());
|
||||
return Scope(new Impl(graph, new Status, new Impl::NameMap, refiner));
|
||||
return Scope(new Impl(graph, new Status, new Impl::NameMap, refiner,
|
||||
/* disable_shape_inference */ false));
|
||||
}
|
||||
|
||||
Scope Scope::DisabledShapeInferenceScope() {
|
||||
Graph* graph = new Graph(OpRegistry::Global());
|
||||
ShapeRefiner* refiner =
|
||||
new ShapeRefiner(graph->versions(), graph->op_registry());
|
||||
return Scope(new Impl(graph, new Status, new Impl::NameMap, refiner,
|
||||
/* disable_shape_inference */ true));
|
||||
}
|
||||
|
||||
Scope::Impl::Impl(const Scope& other, Tags::ScopeName, const string& name,
|
||||
@ -77,7 +88,8 @@ Scope::Impl::Impl(const Scope& other, Tags::ScopeName, const string& name,
|
||||
exit_on_error_(other.impl()->exit_on_error_),
|
||||
kernel_label_(other.impl()->kernel_label_),
|
||||
device_(other.impl()->device_),
|
||||
colocation_constraints_(other.impl()->colocation_constraints_) {}
|
||||
colocation_constraints_(other.impl()->colocation_constraints_),
|
||||
disable_shape_inference_(other.impl()->disable_shape_inference_) {}
|
||||
|
||||
Scope::Impl::Impl(const Scope& other, Tags::OpName, const string& name,
|
||||
const string& op_name)
|
||||
@ -92,7 +104,8 @@ Scope::Impl::Impl(const Scope& other, Tags::OpName, const string& name,
|
||||
exit_on_error_(other.impl()->exit_on_error_),
|
||||
kernel_label_(other.impl()->kernel_label_),
|
||||
device_(other.impl()->device_),
|
||||
colocation_constraints_(other.impl()->colocation_constraints_) {}
|
||||
colocation_constraints_(other.impl()->colocation_constraints_),
|
||||
disable_shape_inference_(other.impl()->disable_shape_inference_) {}
|
||||
|
||||
Scope::Impl::Impl(const Scope& other, Tags::ControlDeps,
|
||||
std::vector<Operation> control_deps, bool clear_control_deps)
|
||||
@ -113,7 +126,8 @@ Scope::Impl::Impl(const Scope& other, Tags::ControlDeps,
|
||||
exit_on_error_(other.impl()->exit_on_error_),
|
||||
kernel_label_(other.impl()->kernel_label_),
|
||||
device_(other.impl()->device_),
|
||||
colocation_constraints_(other.impl()->colocation_constraints_) {}
|
||||
colocation_constraints_(other.impl()->colocation_constraints_),
|
||||
disable_shape_inference_(other.impl()->disable_shape_inference_) {}
|
||||
|
||||
Scope::Impl::Impl(const Scope& other, Tags::Device, const string& device)
|
||||
: graph_(other.impl()->graph_),
|
||||
@ -127,7 +141,8 @@ Scope::Impl::Impl(const Scope& other, Tags::Device, const string& device)
|
||||
exit_on_error_(other.impl()->exit_on_error_),
|
||||
kernel_label_(other.impl()->kernel_label_),
|
||||
device_(device),
|
||||
colocation_constraints_(other.impl()->colocation_constraints_) {}
|
||||
colocation_constraints_(other.impl()->colocation_constraints_),
|
||||
disable_shape_inference_(other.impl()->disable_shape_inference_) {}
|
||||
|
||||
Scope::Impl::Impl(const Scope& other, Tags::SingleUseScope,
|
||||
const string& op_name)
|
||||
@ -142,7 +157,8 @@ Scope::Impl::Impl(const Scope& other, Tags::SingleUseScope,
|
||||
exit_on_error_(other.impl()->exit_on_error_),
|
||||
kernel_label_(other.impl()->kernel_label_),
|
||||
device_(other.impl()->device_),
|
||||
colocation_constraints_(other.impl()->colocation_constraints_) {}
|
||||
colocation_constraints_(other.impl()->colocation_constraints_),
|
||||
disable_shape_inference_(other.impl()->disable_shape_inference_) {}
|
||||
|
||||
Scope::Impl::Impl(const Scope& other, Tags::ExitOnError)
|
||||
: graph_(other.impl()->graph_),
|
||||
@ -156,7 +172,8 @@ Scope::Impl::Impl(const Scope& other, Tags::ExitOnError)
|
||||
exit_on_error_(true),
|
||||
kernel_label_(other.impl()->kernel_label_),
|
||||
device_(other.impl()->device_),
|
||||
colocation_constraints_(other.impl()->colocation_constraints_) {}
|
||||
colocation_constraints_(other.impl()->colocation_constraints_),
|
||||
disable_shape_inference_(other.impl()->disable_shape_inference_) {}
|
||||
|
||||
Scope::Impl::Impl(const Scope& other, Tags::KernelLabel,
|
||||
const string& kernel_label)
|
||||
@ -171,7 +188,8 @@ Scope::Impl::Impl(const Scope& other, Tags::KernelLabel,
|
||||
exit_on_error_(other.impl()->exit_on_error_),
|
||||
kernel_label_(kernel_label),
|
||||
device_(other.impl()->device_),
|
||||
colocation_constraints_(other.impl()->colocation_constraints_) {}
|
||||
colocation_constraints_(other.impl()->colocation_constraints_),
|
||||
disable_shape_inference_(other.impl()->disable_shape_inference_) {}
|
||||
|
||||
Scope::Impl::Impl(const Scope& other, Tags::Colocate,
|
||||
const Operation& colocate_with_op, bool clear_colocations)
|
||||
@ -189,7 +207,8 @@ Scope::Impl::Impl(const Scope& other, Tags::Colocate,
|
||||
colocation_constraints_(
|
||||
clear_colocations
|
||||
? std::unordered_set<string>()
|
||||
: other.impl()->GetColocationConstraints(colocate_with_op)) {}
|
||||
: other.impl()->GetColocationConstraints(colocate_with_op)),
|
||||
disable_shape_inference_(other.impl()->disable_shape_inference_) {}
|
||||
|
||||
std::unordered_set<string> Scope::Impl::GetColocationConstraints(
|
||||
const Operation& colocate_with_op) const {
|
||||
@ -404,6 +423,11 @@ CompositeOpScopes Scope::GetCompositeOpScopes(
|
||||
}
|
||||
}
|
||||
|
||||
Status Scope::DoShapeInference(Node* node) const {
|
||||
if (impl_->disable_shape_inference_) return Status::OK();
|
||||
return impl_->refiner_->AddNode(node);
|
||||
}
|
||||
|
||||
class InternalScope {
|
||||
public:
|
||||
// NewScope doesn't take ownership of the inputs.
|
||||
|
@ -199,6 +199,18 @@ class Scope {
|
||||
// edges from the source and to the sink node, resolves back edges
|
||||
// by name), and makes sure the resulting graph is valid.
|
||||
Status ToGraph(Graph* g) const;
|
||||
|
||||
// Calls AddNode() using this scope's ShapeRefiner. This exists in the public
|
||||
// API to prevent custom op wrappers from needing access to shape_refiner.h or
|
||||
// scope_internal.h.
|
||||
// TODO(skyewm): remove this from public API
|
||||
Status DoShapeInference(Node* node) const;
|
||||
|
||||
// Creates a new root scope that causes all DoShapeInference() calls to return
|
||||
// Status::OK() (on the returned scope and any subscopes). Used for testing.
|
||||
// TODO(skyewm): fix tests that still require this and eventually remove, or
|
||||
// at least remove from public API
|
||||
static Scope DisabledShapeInferenceScope();
|
||||
// END_SKIP_DOXYGEN
|
||||
|
||||
const std::vector<Operation>& control_deps() const;
|
||||
|
@ -58,7 +58,8 @@ class Scope::Impl {
|
||||
enum class Colocate;
|
||||
};
|
||||
|
||||
Impl(Graph* graph, Status* status, NameMap* name_map, ShapeRefiner* refiner);
|
||||
Impl(Graph* graph, Status* status, NameMap* name_map, ShapeRefiner* refiner,
|
||||
bool disable_shape_inference);
|
||||
Impl(const Scope& other, Tags::ScopeName, const string& name,
|
||||
bool copy_names);
|
||||
Impl(const Scope& other, Tags::OpName, const string& name,
|
||||
@ -103,6 +104,10 @@ class Scope::Impl {
|
||||
const string kernel_label_ = "";
|
||||
const string device_ = "";
|
||||
const std::unordered_set<string> colocation_constraints_;
|
||||
|
||||
// If true, Scope::DoShapeInference() always returns Status:OK().
|
||||
// TODO(skyewm): remove this when possible
|
||||
const bool disable_shape_inference_;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -24,6 +25,7 @@ REGISTER_OP("ThrowAway1")
|
||||
.Attr("scope: int")
|
||||
.Attr("builder: int = 1")
|
||||
.Attr("while: int")
|
||||
.SetShapeFn(shape_inference::UnknownShape)
|
||||
.Doc(R"doc(
|
||||
Op to test keywords and reserved words in input and attr names.
|
||||
|
||||
@ -36,12 +38,20 @@ REGISTER_OP("ThrowAway2")
|
||||
.Attr("scope: int = 2")
|
||||
.Attr("throw_away2: int = 2")
|
||||
.Attr("attrs: int = 4")
|
||||
.Attr("node: int = 4");
|
||||
.Attr("node: int = 4")
|
||||
.SetShapeFn(shape_inference::UnknownShape);
|
||||
|
||||
REGISTER_OP("ThrowAway3").Output("node: int32");
|
||||
REGISTER_OP("ThrowAway3")
|
||||
.Output("node: int32")
|
||||
.SetShapeFn(shape_inference::UnknownShape);
|
||||
|
||||
REGISTER_OP("ThrowAway4").Input("node: int32");
|
||||
REGISTER_OP("ThrowAway4")
|
||||
.Input("node: int32")
|
||||
.SetShapeFn(shape_inference::UnknownShape);
|
||||
|
||||
REGISTER_OP("ThrowAway5").Output("foo: int32").Attr("node: int = 4");
|
||||
REGISTER_OP("ThrowAway5")
|
||||
.Output("foo: int32")
|
||||
.Attr("node: int = 4")
|
||||
.SetShapeFn(shape_inference::UnknownShape);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -34,7 +34,9 @@ Output Const(const Scope& scope, const Input::Initializer& val) {
|
||||
.Attr("dtype", val.tensor.dtype());
|
||||
scope.UpdateBuilder(&builder);
|
||||
scope.UpdateStatus(builder.Finalize(graph, &ret));
|
||||
if (!scope.ok()) return Output();
|
||||
|
||||
scope.UpdateStatus(scope.DoShapeInference(ret));
|
||||
if (!scope.ok()) return Output();
|
||||
|
||||
return Output(ret);
|
||||
|
@ -56,6 +56,8 @@ Output Const(const Scope& scope, const Input::Initializer& val) {
|
||||
scope.UpdateBuilder(&cast_builder);
|
||||
Node* ret;
|
||||
scope.UpdateStatus(cast_builder.Finalize(scope.graph(), &ret));
|
||||
if (!scope.ok()) return Output();
|
||||
scope.UpdateStatus(scope.DoShapeInference(ret));
|
||||
return Output(ret, 0);
|
||||
}
|
||||
|
||||
|
@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -26,6 +27,7 @@ REGISTER_OP("XlaWhile")
|
||||
.Attr("cond: func")
|
||||
.Attr("body: func")
|
||||
.SetIsStateful()
|
||||
.SetShapeFn(shape_inference::UnknownShape)
|
||||
.Doc(R"doc(
|
||||
output = input; While (Cond(output)) { output = Body(output) }
|
||||
|
||||
|
@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -22,6 +23,7 @@ REGISTER_OP("_XLASend")
|
||||
.Attr("T: type")
|
||||
.Attr("tensor_name: string")
|
||||
.SetIsStateful()
|
||||
.SetShapeFn(shape_inference::UnknownShape)
|
||||
.Doc(R"doc(
|
||||
Sends the named tensor to another XLA computation.
|
||||
|
||||
@ -35,6 +37,7 @@ REGISTER_OP("_XLARecv")
|
||||
.Attr("tensor_name: string")
|
||||
.Attr("shape: shape")
|
||||
.SetIsStateful()
|
||||
.SetShapeFn(shape_inference::UnknownShape)
|
||||
.Doc(R"doc(
|
||||
Receives the named tensor from another XLA computation.
|
||||
|
||||
|
@ -25,6 +25,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
#include "tensorflow/core/framework/resource_mgr.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
@ -76,6 +77,8 @@ class DummyReadResourceCC {
|
||||
scope.UpdateBuilder(&builder);
|
||||
scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
|
||||
if (!scope.ok()) return;
|
||||
scope.UpdateStatus(scope.DoShapeInference(ret));
|
||||
if (!scope.ok()) return;
|
||||
this->output_ = Output(ret, 0);
|
||||
}
|
||||
Node* node() const { return output_.node(); }
|
||||
@ -86,6 +89,7 @@ class DummyReadResourceCC {
|
||||
REGISTER_OP("DummyReadResource")
|
||||
.Input("input: int32")
|
||||
.Output("output: int32")
|
||||
.SetShapeFn(shape_inference::UnknownShape)
|
||||
.Doc(R"doc(
|
||||
A dummy Op.
|
||||
|
||||
|
@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
|
||||
@ -67,6 +68,7 @@ REGISTER_OP("TPUReplicate")
|
||||
.Input("broadcast_inputs: Tbroadcast_inputs")
|
||||
.Input("variables: NumVariables * resource")
|
||||
.Output("outputs: output_types")
|
||||
.SetShapeFn(shape_inference::UnknownShape)
|
||||
.Doc(R"doc(
|
||||
Runs replicated computations on a distributed TPU system.
|
||||
|
||||
|
@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
@ -154,7 +155,10 @@ global_tpu_array: A two-dimensional array. For each host (the outer
|
||||
dimension) the array lists the global ids of the TPUs on that host.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("_ShutdownDistributedTPU").SetIsStateful().Doc(R"doc(
|
||||
REGISTER_OP("_ShutdownDistributedTPU")
|
||||
.SetIsStateful()
|
||||
.SetShapeFn(shape_inference::UnknownShape)
|
||||
.Doc(R"doc(
|
||||
An op that shuts down a running distributed TPU system. The Op returns
|
||||
an error if no system is running. This Op must be run on the same
|
||||
TPU_SYSTEM device as the corresponding _ConfigureDistributedTPU was run
|
||||
@ -184,6 +188,7 @@ tpu_ids: A vector containing the global TPU id of each TPU on the host.
|
||||
REGISTER_OP("_DisconnectHostFromDistributedTPUSystem")
|
||||
.Output("number_of_tpu_chips: int32")
|
||||
.SetIsStateful()
|
||||
.SetShapeFn(shape_inference::UnknownShape)
|
||||
.Doc(R"doc(
|
||||
An op that disconnects the TPUs on a host from a running distributed
|
||||
TPU system.
|
||||
@ -196,6 +201,7 @@ REGISTER_OP("ConfigureDistributedTPU")
|
||||
.Output("global_tpu_array: int32")
|
||||
.Attr("embedding_config: string = ''")
|
||||
.SetIsStateful()
|
||||
.SetShapeFn(shape_inference::UnknownShape)
|
||||
.Doc(R"doc(
|
||||
An op that sets up the centralized structures for a distributed TPU
|
||||
system.
|
||||
@ -205,7 +211,10 @@ dimension) the array lists the global ids of the TPUs on that host.
|
||||
embedding_config: Internal use.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("ShutdownDistributedTPU").SetIsStateful().Doc(R"doc(
|
||||
REGISTER_OP("ShutdownDistributedTPU")
|
||||
.SetIsStateful()
|
||||
.SetShapeFn(shape_inference::UnknownShape)
|
||||
.Doc(R"doc(
|
||||
An op that shuts down a running distributed TPU system. The Op returns
|
||||
an error if no system is running.
|
||||
)doc");
|
||||
|
@ -282,6 +282,7 @@ TEST_F(ConstantFoldingTest, TestNoReplaceFunctionCall) {
|
||||
Status status;
|
||||
Node* times_two = s.graph()->AddNode(def, &status);
|
||||
TF_ASSERT_OK(status);
|
||||
TF_ASSERT_OK(s.DoShapeInference(times_two));
|
||||
s.graph()->AddEdge(c.node(), 0, times_two, 0);
|
||||
|
||||
auto times_two_send =
|
||||
@ -297,7 +298,10 @@ TEST_F(ConstantFoldingTest, TestNoReplaceFunctionCall) {
|
||||
EXPECT_FALSE(was_mutated);
|
||||
}
|
||||
|
||||
REGISTER_OP("ConstantFoldingTestOp").Input("a: int64").Output("b: int64");
|
||||
REGISTER_OP("ConstantFoldingTestOp")
|
||||
.Input("a: int64")
|
||||
.Output("b: int64")
|
||||
.SetShapeFn(shape_inference::UnknownShape);
|
||||
|
||||
TEST_F(ConstantFoldingTest, TestNoReplaceNonCPUOp) {
|
||||
Graph g(OpRegistry::Global());
|
||||
@ -312,6 +316,7 @@ TEST_F(ConstantFoldingTest, TestNoReplaceNonCPUOp) {
|
||||
Status status;
|
||||
Node* non_cpu = s.graph()->AddNode(def, &status);
|
||||
TF_ASSERT_OK(status);
|
||||
TF_ASSERT_OK(s.DoShapeInference(non_cpu));
|
||||
|
||||
auto non_cpu_send =
|
||||
ops::_Send(s.WithOpName("non_cpu_send"), Output(non_cpu),
|
||||
|
@ -284,6 +284,7 @@ Output Call(Scope* scope, const string& op_name, const string& fn_name,
|
||||
Status status;
|
||||
Node* n = scope->graph()->AddNode(def, &status);
|
||||
TF_CHECK_OK(status);
|
||||
TF_CHECK_OK(scope->DoShapeInference(n));
|
||||
for (int i = 0; i < inputs.size(); ++i) {
|
||||
scope->graph()->AddEdge(inputs[i].node(), inputs[i].index(), n, i);
|
||||
}
|
||||
@ -989,7 +990,7 @@ TEST(OptimizationTest, RemoveDeadNodes) {
|
||||
|
||||
GraphDef expected;
|
||||
{
|
||||
Scope s = Scope::NewRootScope();
|
||||
Scope s = Scope::DisabledShapeInferenceScope();
|
||||
auto x = ops::_Arg(s.WithOpName("x"), DT_INT32, 0);
|
||||
auto o = ops::Const(s.WithOpName("o"), 1);
|
||||
auto keep_me = ops::RandomUniform(s.WithOpName("keep_me"), {o}, DT_FLOAT);
|
||||
@ -1070,7 +1071,7 @@ TEST(OptimizationTest, RemoveIdentityNodes) {
|
||||
{{"y"}, "Add", {"a", "o"}, {{"T", T}}}});
|
||||
|
||||
{
|
||||
Scope s = Scope::NewRootScope();
|
||||
Scope s = Scope::DisabledShapeInferenceScope();
|
||||
auto x = ops::_Arg(s.WithOpName("x"), DT_INT32, 0);
|
||||
auto o = ops::Const(s.WithOpName("o"), 1);
|
||||
auto a = ops::Square(s.WithOpName("a"), x);
|
||||
@ -1087,7 +1088,7 @@ TEST(OptimizationTest, RemoveIdentityNodes) {
|
||||
}
|
||||
|
||||
{
|
||||
Scope s = Scope::NewRootScope();
|
||||
Scope s = Scope::DisabledShapeInferenceScope();
|
||||
auto x = ops::_Arg(s.WithOpName("x"), DT_INT32, 0);
|
||||
auto o = ops::Const(s.WithOpName("o"), 1);
|
||||
auto a = ops::Square(s.WithOpName("a"), x);
|
||||
@ -1137,7 +1138,7 @@ TEST(OptimizationTest, RemoveListArrayConverter) {
|
||||
{{"o", "o:sum"}});
|
||||
|
||||
{
|
||||
Scope scope = Scope::NewRootScope();
|
||||
Scope scope = Scope::DisabledShapeInferenceScope();
|
||||
auto i = ops::_Arg(scope.WithOpName("i"), DT_FLOAT, 0);
|
||||
auto zero = ops::Const(scope.WithOpName("zero"), 0);
|
||||
auto s = ops::Split(scope.WithOpName("s"), zero, i, 4);
|
||||
@ -1222,7 +1223,7 @@ TEST(OptimizationTest, RemoveListArrayConverter_WithContolDeps) {
|
||||
{{"o", "o:sum"}});
|
||||
|
||||
{
|
||||
Scope s = Scope::NewRootScope();
|
||||
Scope s = Scope::DisabledShapeInferenceScope();
|
||||
auto i = ops::_Arg(s.WithOpName("i"), DT_FLOAT, 0);
|
||||
auto dummy = ops::Const(s.WithOpName("dummy"), 0);
|
||||
auto x = ops::_ListToArray(s.WithOpName("x").WithControlDependencies(dummy),
|
||||
|
@ -64,7 +64,7 @@ TEST_F(GpuStreamUtilTest, EmptyGraph) {
|
||||
}
|
||||
|
||||
TEST_F(GpuStreamUtilTest, SimpleGraphOneStream) {
|
||||
auto root = Scope::NewRootScope().ExitOnError();
|
||||
auto root = Scope::DisabledShapeInferenceScope().ExitOnError();
|
||||
ops::MatMul(root, {}, {});
|
||||
Graph g(OpRegistry::Global());
|
||||
TF_ASSERT_OK(root.ToGraph(&g));
|
||||
@ -83,7 +83,7 @@ TEST_F(GpuStreamUtilTest, SimpleGraphOneStream) {
|
||||
}
|
||||
|
||||
TEST_F(GpuStreamUtilTest, SimpleGraphManyStreams) {
|
||||
auto root = Scope::NewRootScope().ExitOnError();
|
||||
auto root = Scope::DisabledShapeInferenceScope().ExitOnError();
|
||||
ops::MatMul(root, {}, {});
|
||||
Graph g(OpRegistry::Global());
|
||||
TF_ASSERT_OK(root.ToGraph(&g));
|
||||
@ -104,7 +104,7 @@ TEST_F(GpuStreamUtilTest, SimpleGraphManyStreams) {
|
||||
}
|
||||
|
||||
TEST_F(GpuStreamUtilTest, StreamOverrides) {
|
||||
auto root = Scope::NewRootScope().ExitOnError();
|
||||
auto root = Scope::DisabledShapeInferenceScope().ExitOnError();
|
||||
ops::_Recv(root.WithOpName("input"), DT_FLOAT, "input", "/cpu:0", 0,
|
||||
"/device:GPU:0");
|
||||
Output n = ops::MatMul(root, {}, {});
|
||||
|
@ -882,7 +882,7 @@ TEST_F(ShapeRefinerTest, ConstantValueAsShape_Shape) {
|
||||
}
|
||||
|
||||
TEST_F(ShapeRefinerTest, ConstantValueAsShape_PackInt32) {
|
||||
Scope root = Scope::NewRootScope();
|
||||
Scope root = Scope::DisabledShapeInferenceScope();
|
||||
Node* scalar_non_const;
|
||||
TF_ASSERT_OK(NodeBuilder("in", "NonConstScalarInt32")
|
||||
.Finalize(root.graph(), &scalar_non_const));
|
||||
@ -914,7 +914,7 @@ TEST_F(ShapeRefinerTest, ConstantValueAsShape_PackInt32) {
|
||||
}
|
||||
|
||||
TEST_F(ShapeRefinerTest, ConstantValueAsShape_PackInt64) {
|
||||
Scope root = Scope::NewRootScope();
|
||||
Scope root = Scope::DisabledShapeInferenceScope();
|
||||
Node* scalar_non_const;
|
||||
TF_ASSERT_OK(NodeBuilder("in", "NonConstScalarInt64")
|
||||
.Finalize(root.graph(), &scalar_non_const));
|
||||
@ -997,7 +997,7 @@ TEST_F(ShapeRefinerTest, ConstantValueAsShape_PackInvalidInput) {
|
||||
}
|
||||
|
||||
TEST_F(ShapeRefinerTest, ConstantValueAsShape_Concat) {
|
||||
Scope root = Scope::NewRootScope();
|
||||
Scope root = Scope::DisabledShapeInferenceScope();
|
||||
Graph* g = root.graph();
|
||||
Node* partial_1;
|
||||
Node* partial_2;
|
||||
@ -1034,7 +1034,7 @@ TEST_F(ShapeRefinerTest, ConstantValueAsShape_Concat) {
|
||||
}
|
||||
|
||||
TEST_F(ShapeRefinerTest, ConstantValueAsShape_ConcatWithUnknown) {
|
||||
Scope root = Scope::NewRootScope();
|
||||
Scope root = Scope::DisabledShapeInferenceScope();
|
||||
Graph* g = root.graph();
|
||||
Node* scalar_non_const;
|
||||
TF_ASSERT_OK(NodeBuilder("in", "NonConstScalarInt32")
|
||||
@ -1077,7 +1077,7 @@ TEST_F(ShapeRefinerTest, ConstantValueAsShape_ConcatWithUnknown) {
|
||||
}
|
||||
|
||||
TEST_F(ShapeRefinerTest, ConstantValueAsShape_ConcatInvalidDimValue) {
|
||||
Scope root = Scope::NewRootScope();
|
||||
Scope root = Scope::DisabledShapeInferenceScope();
|
||||
Graph* g = root.graph();
|
||||
Node* scalar_non_const;
|
||||
TF_ASSERT_OK(NodeBuilder("in", "NonConstScalarInt32")
|
||||
|
@ -25,6 +25,7 @@ limitations under the License.
|
||||
#include "tensorflow/cc/ops/math_ops.h"
|
||||
#include "tensorflow/cc/ops/random_ops.h"
|
||||
#include "tensorflow/cc/ops/sendrecv_ops.h"
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
#include "tensorflow/core/framework/function_testlib.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/versions.pb.h"
|
||||
@ -141,9 +142,17 @@ void CheckLoopConstruction(const GraphDef& graph_def) {
|
||||
}
|
||||
}
|
||||
|
||||
REGISTER_OP("FloatInput").Output("o: float");
|
||||
REGISTER_OP("BoolInput").Output("o: bool");
|
||||
REGISTER_OP("Combine").Input("a: float").Input("b: float").Output("o: float");
|
||||
REGISTER_OP("FloatInput")
|
||||
.Output("o: float")
|
||||
.SetShapeFn(shape_inference::UnknownShape);
|
||||
REGISTER_OP("BoolInput")
|
||||
.Output("o: bool")
|
||||
.SetShapeFn(shape_inference::UnknownShape);
|
||||
REGISTER_OP("Combine")
|
||||
.Input("a: float")
|
||||
.Input("b: float")
|
||||
.Output("o: float")
|
||||
.SetShapeFn(shape_inference::UnknownShape);
|
||||
|
||||
Output ConstructOp(const Scope& scope, const string& op_type,
|
||||
const gtl::ArraySlice<Input>& inputs) {
|
||||
@ -158,6 +167,8 @@ Output ConstructOp(const Scope& scope, const string& op_type,
|
||||
Node* ret;
|
||||
scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
|
||||
if (!scope.ok()) return Output();
|
||||
scope.UpdateStatus(scope.DoShapeInference(ret));
|
||||
if (!scope.ok()) return Output();
|
||||
return Output(ret);
|
||||
}
|
||||
|
||||
|
@ -28,7 +28,7 @@ namespace {
|
||||
class AutoParallelTest : public ::testing::Test {};
|
||||
|
||||
TEST_F(AutoParallelTest, SimpleParallel) {
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||
tensorflow::Scope s = tensorflow::Scope::DisabledShapeInferenceScope();
|
||||
Output constant_a = ops::Const(s.WithOpName("constant_a"), 1.0f, {1});
|
||||
Output constant_b = ops::Const(s.WithOpName("constant_b"), 1, {1});
|
||||
Output var = ops::Variable(s.WithOpName("var"), {1}, DT_FLOAT);
|
||||
|
@ -35,7 +35,7 @@ namespace tensorflow {
|
||||
using namespace ops; // NOLINT(build/namespaces)
|
||||
|
||||
TEST(EncodeWavOpTest, EncodeWavTest) {
|
||||
Scope root = Scope::NewRootScope();
|
||||
Scope root = Scope::DisabledShapeInferenceScope();
|
||||
|
||||
Tensor audio_tensor(DT_FLOAT, {4, 2});
|
||||
test::FillValues<float>(
|
||||
|
@ -88,7 +88,7 @@ class FuzzSession {
|
||||
}
|
||||
initialized_ = true;
|
||||
|
||||
Scope root = Scope::NewRootScope().ExitOnError();
|
||||
Scope root = Scope::DisabledShapeInferenceScope().ExitOnError();
|
||||
SessionOptions options;
|
||||
session_ = std::unique_ptr<Session>(NewSession(options));
|
||||
|
||||
|
@ -121,7 +121,7 @@ TEST(ImmutableConstantOpTest, ExecutionError) {
|
||||
const TensorShape kBadTensorShape({40, 100});
|
||||
const TensorShape kTestTensorShapeT({1, 4});
|
||||
|
||||
auto root = Scope::NewRootScope().ExitOnError();
|
||||
auto root = Scope::DisabledShapeInferenceScope().ExitOnError();
|
||||
auto node1 =
|
||||
ops::ImmutableConst(root, DT_FLOAT, kBadTensorShape, "test:///2");
|
||||
auto node2 =
|
||||
|
@ -35,7 +35,7 @@ namespace tensorflow {
|
||||
using namespace ops; // NOLINT(build/namespaces)
|
||||
|
||||
TEST(MfccOpTest, SimpleTest) {
|
||||
Scope root = Scope::NewRootScope();
|
||||
Scope root = Scope::DisabledShapeInferenceScope();
|
||||
|
||||
Tensor spectrogram_tensor(DT_FLOAT, TensorShape({1, 1, 513}));
|
||||
test::FillIota<float>(&spectrogram_tensor, 1.0f);
|
||||
|
@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -26,6 +27,7 @@ REGISTER_OP("_Send")
|
||||
.Attr("recv_device: string")
|
||||
.Attr("client_terminated: bool = false")
|
||||
.SetIsStateful()
|
||||
.SetShapeFn(shape_inference::UnknownShape)
|
||||
.Doc(R"doc(
|
||||
Sends the named tensor from send_device to recv_device.
|
||||
|
||||
@ -49,6 +51,7 @@ REGISTER_OP("_Recv")
|
||||
.Attr("recv_device: string")
|
||||
.Attr("client_terminated: bool = false")
|
||||
.SetIsStateful()
|
||||
.SetShapeFn(shape_inference::UnknownShape)
|
||||
.Doc(R"doc(
|
||||
Receives the named tensor from send_device on recv_device.
|
||||
|
||||
@ -72,6 +75,7 @@ REGISTER_OP("_HostSend")
|
||||
.Attr("recv_device: string")
|
||||
.Attr("client_terminated: bool = false")
|
||||
.SetIsStateful()
|
||||
.SetShapeFn(shape_inference::UnknownShape)
|
||||
.Doc(R"doc(
|
||||
Sends the named tensor from send_device to recv_device.
|
||||
|
||||
@ -98,6 +102,7 @@ REGISTER_OP("_HostRecv")
|
||||
.Attr("recv_device: string")
|
||||
.Attr("client_terminated: bool = false")
|
||||
.SetIsStateful()
|
||||
.SetShapeFn(shape_inference::UnknownShape)
|
||||
.Doc(R"doc(
|
||||
Receives the named tensor from send_device on recv_device.
|
||||
|
||||
|
@ -36,7 +36,7 @@ class FakeQuantizeTrainingTest : public ::testing::Test {};
|
||||
// TODO(suharshs): Once we implement the fake_quantize_training transform
|
||||
// using the GTT, write proper tests of the transform here.
|
||||
TEST_F(FakeQuantizeTrainingTest, TransformOccurred) {
|
||||
auto root = tensorflow::Scope::NewRootScope();
|
||||
auto root = tensorflow::Scope::DisabledShapeInferenceScope();
|
||||
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
|
||||
|
||||
Tensor a_data(DT_FLOAT, TensorShape());
|
||||
|
@ -40,7 +40,7 @@ class QuantizeWeightsTest : public ::testing::Test {
|
||||
const TensorShape& weight_shape,
|
||||
std::initializer_list<float> weight_values,
|
||||
GraphDef* original_graph_def) {
|
||||
auto root = tensorflow::Scope::NewRootScope();
|
||||
auto root = tensorflow::Scope::DisabledShapeInferenceScope();
|
||||
|
||||
Tensor input_data(DT_FLOAT, input_shape);
|
||||
test::FillValues<float>(&input_data, input_values);
|
||||
|
@ -622,7 +622,7 @@ class TransformUtilsTest : public ::testing::Test {
|
||||
}
|
||||
|
||||
void TestRenameNodeInputsWithWildcard() {
|
||||
auto root = tensorflow::Scope::NewRootScope();
|
||||
auto root = tensorflow::Scope::DisabledShapeInferenceScope();
|
||||
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
|
||||
|
||||
const int width = 10;
|
||||
|
Loading…
Reference in New Issue
Block a user