C++ API: run shape inference as nodes are constructed
Here's an example of the new generated code: AddN::AddN(const ::tensorflow::Scope& scope, ::tensorflow::InputList inputs) { if (!scope.ok()) return; auto _inputs = ::tensorflow::ops::AsNodeOutList(scope, inputs); if (!scope.ok()) return; ::tensorflow::Node* ret; const auto unique_name = scope.GetUniqueNameForOp("AddN"); auto builder = ::tensorflow::NodeBuilder(unique_name, "AddN") .Input(_inputs) ; scope.UpdateBuilder(&builder); scope.UpdateStatus(builder.Finalize(scope.graph(), &ret)); if (!scope.ok()) return; scope.UpdateStatus(scope.DoShapeInference(ret)); this->sum = Output(ret, 0); } Enabling shape inference unfortunately broke many tests. I fixed some of them, but for others I introduced a Scope::DisabledShapeInferenceScope() static method that returns a scope that doesn't perform shape inference. Eventually we should fix the tests that use this and remove it. PiperOrigin-RevId: 165378429
This commit is contained in:
parent
9ba0abc2f0
commit
477d49c9ea
@ -812,12 +812,8 @@ string OpInfo::GetConstructorBody() const {
|
|||||||
strings::StrAppend(&body, " ", scope_str, ".UpdateStatus(builder.Finalize(",
|
strings::StrAppend(&body, " ", scope_str, ".UpdateStatus(builder.Finalize(",
|
||||||
scope_str, ".graph(), &ret));\n");
|
scope_str, ".graph(), &ret));\n");
|
||||||
strings::StrAppend(&body, " ", return_on_error, "\n");
|
strings::StrAppend(&body, " ", return_on_error, "\n");
|
||||||
|
strings::StrAppend(&body, " ", scope_str, ".UpdateStatus(", scope_str,
|
||||||
// TODO(b/28152992): Enable this code-path once we have converted
|
".DoShapeInference(ret));\n");
|
||||||
// all python shape functions to call their C++ versions.
|
|
||||||
|
|
||||||
// strings::StrAppend(&body, " ", scope_str, ".UpdateStatus(", scope_str,
|
|
||||||
// ".refiner()->AddNode(ret));\n");
|
|
||||||
|
|
||||||
GetOutput(&body);
|
GetOutput(&body);
|
||||||
return body;
|
return body;
|
||||||
|
@ -37,13 +37,14 @@ Scope& Scope::operator=(const Scope& other) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
Scope::Impl::Impl(Graph* graph, Status* status, NameMap* name_map,
|
Scope::Impl::Impl(Graph* graph, Status* status, NameMap* name_map,
|
||||||
ShapeRefiner* refiner)
|
ShapeRefiner* refiner, bool disable_shape_inference)
|
||||||
: graph_(graph),
|
: graph_(graph),
|
||||||
status_(status),
|
status_(status),
|
||||||
name_map_(name_map),
|
name_map_(name_map),
|
||||||
refiner_(refiner),
|
refiner_(refiner),
|
||||||
scope_used_(nullptr),
|
scope_used_(nullptr),
|
||||||
colocation_constraints_() {}
|
colocation_constraints_(),
|
||||||
|
disable_shape_inference_(disable_shape_inference) {}
|
||||||
|
|
||||||
Scope::Impl::Impl(const std::shared_ptr<Graph>& graph,
|
Scope::Impl::Impl(const std::shared_ptr<Graph>& graph,
|
||||||
const std::shared_ptr<Status>& status,
|
const std::shared_ptr<Status>& status,
|
||||||
@ -54,13 +55,23 @@ Scope::Impl::Impl(const std::shared_ptr<Graph>& graph,
|
|||||||
name_map_(name_map),
|
name_map_(name_map),
|
||||||
refiner_(refiner),
|
refiner_(refiner),
|
||||||
scope_used_(nullptr),
|
scope_used_(nullptr),
|
||||||
colocation_constraints_() {}
|
colocation_constraints_(),
|
||||||
|
disable_shape_inference_(false) {}
|
||||||
|
|
||||||
Scope Scope::NewRootScope() {
|
Scope Scope::NewRootScope() {
|
||||||
Graph* graph = new Graph(OpRegistry::Global());
|
Graph* graph = new Graph(OpRegistry::Global());
|
||||||
ShapeRefiner* refiner =
|
ShapeRefiner* refiner =
|
||||||
new ShapeRefiner(graph->versions(), graph->op_registry());
|
new ShapeRefiner(graph->versions(), graph->op_registry());
|
||||||
return Scope(new Impl(graph, new Status, new Impl::NameMap, refiner));
|
return Scope(new Impl(graph, new Status, new Impl::NameMap, refiner,
|
||||||
|
/* disable_shape_inference */ false));
|
||||||
|
}
|
||||||
|
|
||||||
|
Scope Scope::DisabledShapeInferenceScope() {
|
||||||
|
Graph* graph = new Graph(OpRegistry::Global());
|
||||||
|
ShapeRefiner* refiner =
|
||||||
|
new ShapeRefiner(graph->versions(), graph->op_registry());
|
||||||
|
return Scope(new Impl(graph, new Status, new Impl::NameMap, refiner,
|
||||||
|
/* disable_shape_inference */ true));
|
||||||
}
|
}
|
||||||
|
|
||||||
Scope::Impl::Impl(const Scope& other, Tags::ScopeName, const string& name,
|
Scope::Impl::Impl(const Scope& other, Tags::ScopeName, const string& name,
|
||||||
@ -77,7 +88,8 @@ Scope::Impl::Impl(const Scope& other, Tags::ScopeName, const string& name,
|
|||||||
exit_on_error_(other.impl()->exit_on_error_),
|
exit_on_error_(other.impl()->exit_on_error_),
|
||||||
kernel_label_(other.impl()->kernel_label_),
|
kernel_label_(other.impl()->kernel_label_),
|
||||||
device_(other.impl()->device_),
|
device_(other.impl()->device_),
|
||||||
colocation_constraints_(other.impl()->colocation_constraints_) {}
|
colocation_constraints_(other.impl()->colocation_constraints_),
|
||||||
|
disable_shape_inference_(other.impl()->disable_shape_inference_) {}
|
||||||
|
|
||||||
Scope::Impl::Impl(const Scope& other, Tags::OpName, const string& name,
|
Scope::Impl::Impl(const Scope& other, Tags::OpName, const string& name,
|
||||||
const string& op_name)
|
const string& op_name)
|
||||||
@ -92,7 +104,8 @@ Scope::Impl::Impl(const Scope& other, Tags::OpName, const string& name,
|
|||||||
exit_on_error_(other.impl()->exit_on_error_),
|
exit_on_error_(other.impl()->exit_on_error_),
|
||||||
kernel_label_(other.impl()->kernel_label_),
|
kernel_label_(other.impl()->kernel_label_),
|
||||||
device_(other.impl()->device_),
|
device_(other.impl()->device_),
|
||||||
colocation_constraints_(other.impl()->colocation_constraints_) {}
|
colocation_constraints_(other.impl()->colocation_constraints_),
|
||||||
|
disable_shape_inference_(other.impl()->disable_shape_inference_) {}
|
||||||
|
|
||||||
Scope::Impl::Impl(const Scope& other, Tags::ControlDeps,
|
Scope::Impl::Impl(const Scope& other, Tags::ControlDeps,
|
||||||
std::vector<Operation> control_deps, bool clear_control_deps)
|
std::vector<Operation> control_deps, bool clear_control_deps)
|
||||||
@ -113,7 +126,8 @@ Scope::Impl::Impl(const Scope& other, Tags::ControlDeps,
|
|||||||
exit_on_error_(other.impl()->exit_on_error_),
|
exit_on_error_(other.impl()->exit_on_error_),
|
||||||
kernel_label_(other.impl()->kernel_label_),
|
kernel_label_(other.impl()->kernel_label_),
|
||||||
device_(other.impl()->device_),
|
device_(other.impl()->device_),
|
||||||
colocation_constraints_(other.impl()->colocation_constraints_) {}
|
colocation_constraints_(other.impl()->colocation_constraints_),
|
||||||
|
disable_shape_inference_(other.impl()->disable_shape_inference_) {}
|
||||||
|
|
||||||
Scope::Impl::Impl(const Scope& other, Tags::Device, const string& device)
|
Scope::Impl::Impl(const Scope& other, Tags::Device, const string& device)
|
||||||
: graph_(other.impl()->graph_),
|
: graph_(other.impl()->graph_),
|
||||||
@ -127,7 +141,8 @@ Scope::Impl::Impl(const Scope& other, Tags::Device, const string& device)
|
|||||||
exit_on_error_(other.impl()->exit_on_error_),
|
exit_on_error_(other.impl()->exit_on_error_),
|
||||||
kernel_label_(other.impl()->kernel_label_),
|
kernel_label_(other.impl()->kernel_label_),
|
||||||
device_(device),
|
device_(device),
|
||||||
colocation_constraints_(other.impl()->colocation_constraints_) {}
|
colocation_constraints_(other.impl()->colocation_constraints_),
|
||||||
|
disable_shape_inference_(other.impl()->disable_shape_inference_) {}
|
||||||
|
|
||||||
Scope::Impl::Impl(const Scope& other, Tags::SingleUseScope,
|
Scope::Impl::Impl(const Scope& other, Tags::SingleUseScope,
|
||||||
const string& op_name)
|
const string& op_name)
|
||||||
@ -142,7 +157,8 @@ Scope::Impl::Impl(const Scope& other, Tags::SingleUseScope,
|
|||||||
exit_on_error_(other.impl()->exit_on_error_),
|
exit_on_error_(other.impl()->exit_on_error_),
|
||||||
kernel_label_(other.impl()->kernel_label_),
|
kernel_label_(other.impl()->kernel_label_),
|
||||||
device_(other.impl()->device_),
|
device_(other.impl()->device_),
|
||||||
colocation_constraints_(other.impl()->colocation_constraints_) {}
|
colocation_constraints_(other.impl()->colocation_constraints_),
|
||||||
|
disable_shape_inference_(other.impl()->disable_shape_inference_) {}
|
||||||
|
|
||||||
Scope::Impl::Impl(const Scope& other, Tags::ExitOnError)
|
Scope::Impl::Impl(const Scope& other, Tags::ExitOnError)
|
||||||
: graph_(other.impl()->graph_),
|
: graph_(other.impl()->graph_),
|
||||||
@ -156,7 +172,8 @@ Scope::Impl::Impl(const Scope& other, Tags::ExitOnError)
|
|||||||
exit_on_error_(true),
|
exit_on_error_(true),
|
||||||
kernel_label_(other.impl()->kernel_label_),
|
kernel_label_(other.impl()->kernel_label_),
|
||||||
device_(other.impl()->device_),
|
device_(other.impl()->device_),
|
||||||
colocation_constraints_(other.impl()->colocation_constraints_) {}
|
colocation_constraints_(other.impl()->colocation_constraints_),
|
||||||
|
disable_shape_inference_(other.impl()->disable_shape_inference_) {}
|
||||||
|
|
||||||
Scope::Impl::Impl(const Scope& other, Tags::KernelLabel,
|
Scope::Impl::Impl(const Scope& other, Tags::KernelLabel,
|
||||||
const string& kernel_label)
|
const string& kernel_label)
|
||||||
@ -171,7 +188,8 @@ Scope::Impl::Impl(const Scope& other, Tags::KernelLabel,
|
|||||||
exit_on_error_(other.impl()->exit_on_error_),
|
exit_on_error_(other.impl()->exit_on_error_),
|
||||||
kernel_label_(kernel_label),
|
kernel_label_(kernel_label),
|
||||||
device_(other.impl()->device_),
|
device_(other.impl()->device_),
|
||||||
colocation_constraints_(other.impl()->colocation_constraints_) {}
|
colocation_constraints_(other.impl()->colocation_constraints_),
|
||||||
|
disable_shape_inference_(other.impl()->disable_shape_inference_) {}
|
||||||
|
|
||||||
Scope::Impl::Impl(const Scope& other, Tags::Colocate,
|
Scope::Impl::Impl(const Scope& other, Tags::Colocate,
|
||||||
const Operation& colocate_with_op, bool clear_colocations)
|
const Operation& colocate_with_op, bool clear_colocations)
|
||||||
@ -189,7 +207,8 @@ Scope::Impl::Impl(const Scope& other, Tags::Colocate,
|
|||||||
colocation_constraints_(
|
colocation_constraints_(
|
||||||
clear_colocations
|
clear_colocations
|
||||||
? std::unordered_set<string>()
|
? std::unordered_set<string>()
|
||||||
: other.impl()->GetColocationConstraints(colocate_with_op)) {}
|
: other.impl()->GetColocationConstraints(colocate_with_op)),
|
||||||
|
disable_shape_inference_(other.impl()->disable_shape_inference_) {}
|
||||||
|
|
||||||
std::unordered_set<string> Scope::Impl::GetColocationConstraints(
|
std::unordered_set<string> Scope::Impl::GetColocationConstraints(
|
||||||
const Operation& colocate_with_op) const {
|
const Operation& colocate_with_op) const {
|
||||||
@ -404,6 +423,11 @@ CompositeOpScopes Scope::GetCompositeOpScopes(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Status Scope::DoShapeInference(Node* node) const {
|
||||||
|
if (impl_->disable_shape_inference_) return Status::OK();
|
||||||
|
return impl_->refiner_->AddNode(node);
|
||||||
|
}
|
||||||
|
|
||||||
class InternalScope {
|
class InternalScope {
|
||||||
public:
|
public:
|
||||||
// NewScope doesn't take ownership of the inputs.
|
// NewScope doesn't take ownership of the inputs.
|
||||||
|
@ -199,6 +199,18 @@ class Scope {
|
|||||||
// edges from the source and to the sink node, resolves back edges
|
// edges from the source and to the sink node, resolves back edges
|
||||||
// by name), and makes sure the resulting graph is valid.
|
// by name), and makes sure the resulting graph is valid.
|
||||||
Status ToGraph(Graph* g) const;
|
Status ToGraph(Graph* g) const;
|
||||||
|
|
||||||
|
// Calls AddNode() using this scope's ShapeRefiner. This exists in the public
|
||||||
|
// API to prevent custom op wrappers from needing access to shape_refiner.h or
|
||||||
|
// scope_internal.h.
|
||||||
|
// TODO(skyewm): remove this from public API
|
||||||
|
Status DoShapeInference(Node* node) const;
|
||||||
|
|
||||||
|
// Creates a new root scope that causes all DoShapeInference() calls to return
|
||||||
|
// Status::OK() (on the returned scope and any subscopes). Used for testing.
|
||||||
|
// TODO(skyewm): fix tests that still require this and eventually remove, or
|
||||||
|
// at least remove from public API
|
||||||
|
static Scope DisabledShapeInferenceScope();
|
||||||
// END_SKIP_DOXYGEN
|
// END_SKIP_DOXYGEN
|
||||||
|
|
||||||
const std::vector<Operation>& control_deps() const;
|
const std::vector<Operation>& control_deps() const;
|
||||||
|
@ -58,7 +58,8 @@ class Scope::Impl {
|
|||||||
enum class Colocate;
|
enum class Colocate;
|
||||||
};
|
};
|
||||||
|
|
||||||
Impl(Graph* graph, Status* status, NameMap* name_map, ShapeRefiner* refiner);
|
Impl(Graph* graph, Status* status, NameMap* name_map, ShapeRefiner* refiner,
|
||||||
|
bool disable_shape_inference);
|
||||||
Impl(const Scope& other, Tags::ScopeName, const string& name,
|
Impl(const Scope& other, Tags::ScopeName, const string& name,
|
||||||
bool copy_names);
|
bool copy_names);
|
||||||
Impl(const Scope& other, Tags::OpName, const string& name,
|
Impl(const Scope& other, Tags::OpName, const string& name,
|
||||||
@ -103,6 +104,10 @@ class Scope::Impl {
|
|||||||
const string kernel_label_ = "";
|
const string kernel_label_ = "";
|
||||||
const string device_ = "";
|
const string device_ = "";
|
||||||
const std::unordered_set<string> colocation_constraints_;
|
const std::unordered_set<string> colocation_constraints_;
|
||||||
|
|
||||||
|
// If true, Scope::DoShapeInference() always returns Status:OK().
|
||||||
|
// TODO(skyewm): remove this when possible
|
||||||
|
const bool disable_shape_inference_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||||
#include "tensorflow/core/framework/op.h"
|
#include "tensorflow/core/framework/op.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
@ -24,6 +25,7 @@ REGISTER_OP("ThrowAway1")
|
|||||||
.Attr("scope: int")
|
.Attr("scope: int")
|
||||||
.Attr("builder: int = 1")
|
.Attr("builder: int = 1")
|
||||||
.Attr("while: int")
|
.Attr("while: int")
|
||||||
|
.SetShapeFn(shape_inference::UnknownShape)
|
||||||
.Doc(R"doc(
|
.Doc(R"doc(
|
||||||
Op to test keywords and reserved words in input and attr names.
|
Op to test keywords and reserved words in input and attr names.
|
||||||
|
|
||||||
@ -36,12 +38,20 @@ REGISTER_OP("ThrowAway2")
|
|||||||
.Attr("scope: int = 2")
|
.Attr("scope: int = 2")
|
||||||
.Attr("throw_away2: int = 2")
|
.Attr("throw_away2: int = 2")
|
||||||
.Attr("attrs: int = 4")
|
.Attr("attrs: int = 4")
|
||||||
.Attr("node: int = 4");
|
.Attr("node: int = 4")
|
||||||
|
.SetShapeFn(shape_inference::UnknownShape);
|
||||||
|
|
||||||
REGISTER_OP("ThrowAway3").Output("node: int32");
|
REGISTER_OP("ThrowAway3")
|
||||||
|
.Output("node: int32")
|
||||||
|
.SetShapeFn(shape_inference::UnknownShape);
|
||||||
|
|
||||||
REGISTER_OP("ThrowAway4").Input("node: int32");
|
REGISTER_OP("ThrowAway4")
|
||||||
|
.Input("node: int32")
|
||||||
|
.SetShapeFn(shape_inference::UnknownShape);
|
||||||
|
|
||||||
REGISTER_OP("ThrowAway5").Output("foo: int32").Attr("node: int = 4");
|
REGISTER_OP("ThrowAway5")
|
||||||
|
.Output("foo: int32")
|
||||||
|
.Attr("node: int = 4")
|
||||||
|
.SetShapeFn(shape_inference::UnknownShape);
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -34,7 +34,9 @@ Output Const(const Scope& scope, const Input::Initializer& val) {
|
|||||||
.Attr("dtype", val.tensor.dtype());
|
.Attr("dtype", val.tensor.dtype());
|
||||||
scope.UpdateBuilder(&builder);
|
scope.UpdateBuilder(&builder);
|
||||||
scope.UpdateStatus(builder.Finalize(graph, &ret));
|
scope.UpdateStatus(builder.Finalize(graph, &ret));
|
||||||
|
if (!scope.ok()) return Output();
|
||||||
|
|
||||||
|
scope.UpdateStatus(scope.DoShapeInference(ret));
|
||||||
if (!scope.ok()) return Output();
|
if (!scope.ok()) return Output();
|
||||||
|
|
||||||
return Output(ret);
|
return Output(ret);
|
||||||
|
@ -56,6 +56,8 @@ Output Const(const Scope& scope, const Input::Initializer& val) {
|
|||||||
scope.UpdateBuilder(&cast_builder);
|
scope.UpdateBuilder(&cast_builder);
|
||||||
Node* ret;
|
Node* ret;
|
||||||
scope.UpdateStatus(cast_builder.Finalize(scope.graph(), &ret));
|
scope.UpdateStatus(cast_builder.Finalize(scope.graph(), &ret));
|
||||||
|
if (!scope.ok()) return Output();
|
||||||
|
scope.UpdateStatus(scope.DoShapeInference(ret));
|
||||||
return Output(ret, 0);
|
return Output(ret, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||||
#include "tensorflow/core/framework/op.h"
|
#include "tensorflow/core/framework/op.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
@ -26,6 +27,7 @@ REGISTER_OP("XlaWhile")
|
|||||||
.Attr("cond: func")
|
.Attr("cond: func")
|
||||||
.Attr("body: func")
|
.Attr("body: func")
|
||||||
.SetIsStateful()
|
.SetIsStateful()
|
||||||
|
.SetShapeFn(shape_inference::UnknownShape)
|
||||||
.Doc(R"doc(
|
.Doc(R"doc(
|
||||||
output = input; While (Cond(output)) { output = Body(output) }
|
output = input; While (Cond(output)) { output = Body(output) }
|
||||||
|
|
||||||
|
@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||||
#include "tensorflow/core/framework/op.h"
|
#include "tensorflow/core/framework/op.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
@ -22,6 +23,7 @@ REGISTER_OP("_XLASend")
|
|||||||
.Attr("T: type")
|
.Attr("T: type")
|
||||||
.Attr("tensor_name: string")
|
.Attr("tensor_name: string")
|
||||||
.SetIsStateful()
|
.SetIsStateful()
|
||||||
|
.SetShapeFn(shape_inference::UnknownShape)
|
||||||
.Doc(R"doc(
|
.Doc(R"doc(
|
||||||
Sends the named tensor to another XLA computation.
|
Sends the named tensor to another XLA computation.
|
||||||
|
|
||||||
@ -35,6 +37,7 @@ REGISTER_OP("_XLARecv")
|
|||||||
.Attr("tensor_name: string")
|
.Attr("tensor_name: string")
|
||||||
.Attr("shape: shape")
|
.Attr("shape: shape")
|
||||||
.SetIsStateful()
|
.SetIsStateful()
|
||||||
|
.SetShapeFn(shape_inference::UnknownShape)
|
||||||
.Doc(R"doc(
|
.Doc(R"doc(
|
||||||
Receives the named tensor from another XLA computation.
|
Receives the named tensor from another XLA computation.
|
||||||
|
|
||||||
|
@ -25,6 +25,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/shape_util.h"
|
#include "tensorflow/compiler/xla/shape_util.h"
|
||||||
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
|
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
|
||||||
#include "tensorflow/core/common_runtime/function.h"
|
#include "tensorflow/core/common_runtime/function.h"
|
||||||
|
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||||
#include "tensorflow/core/framework/resource_mgr.h"
|
#include "tensorflow/core/framework/resource_mgr.h"
|
||||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||||
#include "tensorflow/core/graph/graph.h"
|
#include "tensorflow/core/graph/graph.h"
|
||||||
@ -76,6 +77,8 @@ class DummyReadResourceCC {
|
|||||||
scope.UpdateBuilder(&builder);
|
scope.UpdateBuilder(&builder);
|
||||||
scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
|
scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
|
||||||
if (!scope.ok()) return;
|
if (!scope.ok()) return;
|
||||||
|
scope.UpdateStatus(scope.DoShapeInference(ret));
|
||||||
|
if (!scope.ok()) return;
|
||||||
this->output_ = Output(ret, 0);
|
this->output_ = Output(ret, 0);
|
||||||
}
|
}
|
||||||
Node* node() const { return output_.node(); }
|
Node* node() const { return output_.node(); }
|
||||||
@ -86,6 +89,7 @@ class DummyReadResourceCC {
|
|||||||
REGISTER_OP("DummyReadResource")
|
REGISTER_OP("DummyReadResource")
|
||||||
.Input("input: int32")
|
.Input("input: int32")
|
||||||
.Output("output: int32")
|
.Output("output: int32")
|
||||||
|
.SetShapeFn(shape_inference::UnknownShape)
|
||||||
.Doc(R"doc(
|
.Doc(R"doc(
|
||||||
A dummy Op.
|
A dummy Op.
|
||||||
|
|
||||||
|
@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||||
#include "tensorflow/core/framework/op.h"
|
#include "tensorflow/core/framework/op.h"
|
||||||
#include "tensorflow/core/framework/shape_inference.h"
|
#include "tensorflow/core/framework/shape_inference.h"
|
||||||
|
|
||||||
@ -67,6 +68,7 @@ REGISTER_OP("TPUReplicate")
|
|||||||
.Input("broadcast_inputs: Tbroadcast_inputs")
|
.Input("broadcast_inputs: Tbroadcast_inputs")
|
||||||
.Input("variables: NumVariables * resource")
|
.Input("variables: NumVariables * resource")
|
||||||
.Output("outputs: output_types")
|
.Output("outputs: output_types")
|
||||||
|
.SetShapeFn(shape_inference::UnknownShape)
|
||||||
.Doc(R"doc(
|
.Doc(R"doc(
|
||||||
Runs replicated computations on a distributed TPU system.
|
Runs replicated computations on a distributed TPU system.
|
||||||
|
|
||||||
|
@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||||
#include "tensorflow/core/framework/op.h"
|
#include "tensorflow/core/framework/op.h"
|
||||||
#include "tensorflow/core/framework/shape_inference.h"
|
#include "tensorflow/core/framework/shape_inference.h"
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
@ -154,7 +155,10 @@ global_tpu_array: A two-dimensional array. For each host (the outer
|
|||||||
dimension) the array lists the global ids of the TPUs on that host.
|
dimension) the array lists the global ids of the TPUs on that host.
|
||||||
)doc");
|
)doc");
|
||||||
|
|
||||||
REGISTER_OP("_ShutdownDistributedTPU").SetIsStateful().Doc(R"doc(
|
REGISTER_OP("_ShutdownDistributedTPU")
|
||||||
|
.SetIsStateful()
|
||||||
|
.SetShapeFn(shape_inference::UnknownShape)
|
||||||
|
.Doc(R"doc(
|
||||||
An op that shuts down a running distributed TPU system. The Op returns
|
An op that shuts down a running distributed TPU system. The Op returns
|
||||||
an error if no system is running. This Op must be run on the same
|
an error if no system is running. This Op must be run on the same
|
||||||
TPU_SYSTEM device as the corresponding _ConfigureDistributedTPU was run
|
TPU_SYSTEM device as the corresponding _ConfigureDistributedTPU was run
|
||||||
@ -184,6 +188,7 @@ tpu_ids: A vector containing the global TPU id of each TPU on the host.
|
|||||||
REGISTER_OP("_DisconnectHostFromDistributedTPUSystem")
|
REGISTER_OP("_DisconnectHostFromDistributedTPUSystem")
|
||||||
.Output("number_of_tpu_chips: int32")
|
.Output("number_of_tpu_chips: int32")
|
||||||
.SetIsStateful()
|
.SetIsStateful()
|
||||||
|
.SetShapeFn(shape_inference::UnknownShape)
|
||||||
.Doc(R"doc(
|
.Doc(R"doc(
|
||||||
An op that disconnects the TPUs on a host from a running distributed
|
An op that disconnects the TPUs on a host from a running distributed
|
||||||
TPU system.
|
TPU system.
|
||||||
@ -196,6 +201,7 @@ REGISTER_OP("ConfigureDistributedTPU")
|
|||||||
.Output("global_tpu_array: int32")
|
.Output("global_tpu_array: int32")
|
||||||
.Attr("embedding_config: string = ''")
|
.Attr("embedding_config: string = ''")
|
||||||
.SetIsStateful()
|
.SetIsStateful()
|
||||||
|
.SetShapeFn(shape_inference::UnknownShape)
|
||||||
.Doc(R"doc(
|
.Doc(R"doc(
|
||||||
An op that sets up the centralized structures for a distributed TPU
|
An op that sets up the centralized structures for a distributed TPU
|
||||||
system.
|
system.
|
||||||
@ -205,7 +211,10 @@ dimension) the array lists the global ids of the TPUs on that host.
|
|||||||
embedding_config: Internal use.
|
embedding_config: Internal use.
|
||||||
)doc");
|
)doc");
|
||||||
|
|
||||||
REGISTER_OP("ShutdownDistributedTPU").SetIsStateful().Doc(R"doc(
|
REGISTER_OP("ShutdownDistributedTPU")
|
||||||
|
.SetIsStateful()
|
||||||
|
.SetShapeFn(shape_inference::UnknownShape)
|
||||||
|
.Doc(R"doc(
|
||||||
An op that shuts down a running distributed TPU system. The Op returns
|
An op that shuts down a running distributed TPU system. The Op returns
|
||||||
an error if no system is running.
|
an error if no system is running.
|
||||||
)doc");
|
)doc");
|
||||||
|
@ -282,6 +282,7 @@ TEST_F(ConstantFoldingTest, TestNoReplaceFunctionCall) {
|
|||||||
Status status;
|
Status status;
|
||||||
Node* times_two = s.graph()->AddNode(def, &status);
|
Node* times_two = s.graph()->AddNode(def, &status);
|
||||||
TF_ASSERT_OK(status);
|
TF_ASSERT_OK(status);
|
||||||
|
TF_ASSERT_OK(s.DoShapeInference(times_two));
|
||||||
s.graph()->AddEdge(c.node(), 0, times_two, 0);
|
s.graph()->AddEdge(c.node(), 0, times_two, 0);
|
||||||
|
|
||||||
auto times_two_send =
|
auto times_two_send =
|
||||||
@ -297,7 +298,10 @@ TEST_F(ConstantFoldingTest, TestNoReplaceFunctionCall) {
|
|||||||
EXPECT_FALSE(was_mutated);
|
EXPECT_FALSE(was_mutated);
|
||||||
}
|
}
|
||||||
|
|
||||||
REGISTER_OP("ConstantFoldingTestOp").Input("a: int64").Output("b: int64");
|
REGISTER_OP("ConstantFoldingTestOp")
|
||||||
|
.Input("a: int64")
|
||||||
|
.Output("b: int64")
|
||||||
|
.SetShapeFn(shape_inference::UnknownShape);
|
||||||
|
|
||||||
TEST_F(ConstantFoldingTest, TestNoReplaceNonCPUOp) {
|
TEST_F(ConstantFoldingTest, TestNoReplaceNonCPUOp) {
|
||||||
Graph g(OpRegistry::Global());
|
Graph g(OpRegistry::Global());
|
||||||
@ -312,6 +316,7 @@ TEST_F(ConstantFoldingTest, TestNoReplaceNonCPUOp) {
|
|||||||
Status status;
|
Status status;
|
||||||
Node* non_cpu = s.graph()->AddNode(def, &status);
|
Node* non_cpu = s.graph()->AddNode(def, &status);
|
||||||
TF_ASSERT_OK(status);
|
TF_ASSERT_OK(status);
|
||||||
|
TF_ASSERT_OK(s.DoShapeInference(non_cpu));
|
||||||
|
|
||||||
auto non_cpu_send =
|
auto non_cpu_send =
|
||||||
ops::_Send(s.WithOpName("non_cpu_send"), Output(non_cpu),
|
ops::_Send(s.WithOpName("non_cpu_send"), Output(non_cpu),
|
||||||
|
@ -284,6 +284,7 @@ Output Call(Scope* scope, const string& op_name, const string& fn_name,
|
|||||||
Status status;
|
Status status;
|
||||||
Node* n = scope->graph()->AddNode(def, &status);
|
Node* n = scope->graph()->AddNode(def, &status);
|
||||||
TF_CHECK_OK(status);
|
TF_CHECK_OK(status);
|
||||||
|
TF_CHECK_OK(scope->DoShapeInference(n));
|
||||||
for (int i = 0; i < inputs.size(); ++i) {
|
for (int i = 0; i < inputs.size(); ++i) {
|
||||||
scope->graph()->AddEdge(inputs[i].node(), inputs[i].index(), n, i);
|
scope->graph()->AddEdge(inputs[i].node(), inputs[i].index(), n, i);
|
||||||
}
|
}
|
||||||
@ -989,7 +990,7 @@ TEST(OptimizationTest, RemoveDeadNodes) {
|
|||||||
|
|
||||||
GraphDef expected;
|
GraphDef expected;
|
||||||
{
|
{
|
||||||
Scope s = Scope::NewRootScope();
|
Scope s = Scope::DisabledShapeInferenceScope();
|
||||||
auto x = ops::_Arg(s.WithOpName("x"), DT_INT32, 0);
|
auto x = ops::_Arg(s.WithOpName("x"), DT_INT32, 0);
|
||||||
auto o = ops::Const(s.WithOpName("o"), 1);
|
auto o = ops::Const(s.WithOpName("o"), 1);
|
||||||
auto keep_me = ops::RandomUniform(s.WithOpName("keep_me"), {o}, DT_FLOAT);
|
auto keep_me = ops::RandomUniform(s.WithOpName("keep_me"), {o}, DT_FLOAT);
|
||||||
@ -1070,7 +1071,7 @@ TEST(OptimizationTest, RemoveIdentityNodes) {
|
|||||||
{{"y"}, "Add", {"a", "o"}, {{"T", T}}}});
|
{{"y"}, "Add", {"a", "o"}, {{"T", T}}}});
|
||||||
|
|
||||||
{
|
{
|
||||||
Scope s = Scope::NewRootScope();
|
Scope s = Scope::DisabledShapeInferenceScope();
|
||||||
auto x = ops::_Arg(s.WithOpName("x"), DT_INT32, 0);
|
auto x = ops::_Arg(s.WithOpName("x"), DT_INT32, 0);
|
||||||
auto o = ops::Const(s.WithOpName("o"), 1);
|
auto o = ops::Const(s.WithOpName("o"), 1);
|
||||||
auto a = ops::Square(s.WithOpName("a"), x);
|
auto a = ops::Square(s.WithOpName("a"), x);
|
||||||
@ -1087,7 +1088,7 @@ TEST(OptimizationTest, RemoveIdentityNodes) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
Scope s = Scope::NewRootScope();
|
Scope s = Scope::DisabledShapeInferenceScope();
|
||||||
auto x = ops::_Arg(s.WithOpName("x"), DT_INT32, 0);
|
auto x = ops::_Arg(s.WithOpName("x"), DT_INT32, 0);
|
||||||
auto o = ops::Const(s.WithOpName("o"), 1);
|
auto o = ops::Const(s.WithOpName("o"), 1);
|
||||||
auto a = ops::Square(s.WithOpName("a"), x);
|
auto a = ops::Square(s.WithOpName("a"), x);
|
||||||
@ -1137,7 +1138,7 @@ TEST(OptimizationTest, RemoveListArrayConverter) {
|
|||||||
{{"o", "o:sum"}});
|
{{"o", "o:sum"}});
|
||||||
|
|
||||||
{
|
{
|
||||||
Scope scope = Scope::NewRootScope();
|
Scope scope = Scope::DisabledShapeInferenceScope();
|
||||||
auto i = ops::_Arg(scope.WithOpName("i"), DT_FLOAT, 0);
|
auto i = ops::_Arg(scope.WithOpName("i"), DT_FLOAT, 0);
|
||||||
auto zero = ops::Const(scope.WithOpName("zero"), 0);
|
auto zero = ops::Const(scope.WithOpName("zero"), 0);
|
||||||
auto s = ops::Split(scope.WithOpName("s"), zero, i, 4);
|
auto s = ops::Split(scope.WithOpName("s"), zero, i, 4);
|
||||||
@ -1222,7 +1223,7 @@ TEST(OptimizationTest, RemoveListArrayConverter_WithContolDeps) {
|
|||||||
{{"o", "o:sum"}});
|
{{"o", "o:sum"}});
|
||||||
|
|
||||||
{
|
{
|
||||||
Scope s = Scope::NewRootScope();
|
Scope s = Scope::DisabledShapeInferenceScope();
|
||||||
auto i = ops::_Arg(s.WithOpName("i"), DT_FLOAT, 0);
|
auto i = ops::_Arg(s.WithOpName("i"), DT_FLOAT, 0);
|
||||||
auto dummy = ops::Const(s.WithOpName("dummy"), 0);
|
auto dummy = ops::Const(s.WithOpName("dummy"), 0);
|
||||||
auto x = ops::_ListToArray(s.WithOpName("x").WithControlDependencies(dummy),
|
auto x = ops::_ListToArray(s.WithOpName("x").WithControlDependencies(dummy),
|
||||||
|
@ -64,7 +64,7 @@ TEST_F(GpuStreamUtilTest, EmptyGraph) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(GpuStreamUtilTest, SimpleGraphOneStream) {
|
TEST_F(GpuStreamUtilTest, SimpleGraphOneStream) {
|
||||||
auto root = Scope::NewRootScope().ExitOnError();
|
auto root = Scope::DisabledShapeInferenceScope().ExitOnError();
|
||||||
ops::MatMul(root, {}, {});
|
ops::MatMul(root, {}, {});
|
||||||
Graph g(OpRegistry::Global());
|
Graph g(OpRegistry::Global());
|
||||||
TF_ASSERT_OK(root.ToGraph(&g));
|
TF_ASSERT_OK(root.ToGraph(&g));
|
||||||
@ -83,7 +83,7 @@ TEST_F(GpuStreamUtilTest, SimpleGraphOneStream) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(GpuStreamUtilTest, SimpleGraphManyStreams) {
|
TEST_F(GpuStreamUtilTest, SimpleGraphManyStreams) {
|
||||||
auto root = Scope::NewRootScope().ExitOnError();
|
auto root = Scope::DisabledShapeInferenceScope().ExitOnError();
|
||||||
ops::MatMul(root, {}, {});
|
ops::MatMul(root, {}, {});
|
||||||
Graph g(OpRegistry::Global());
|
Graph g(OpRegistry::Global());
|
||||||
TF_ASSERT_OK(root.ToGraph(&g));
|
TF_ASSERT_OK(root.ToGraph(&g));
|
||||||
@ -104,7 +104,7 @@ TEST_F(GpuStreamUtilTest, SimpleGraphManyStreams) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(GpuStreamUtilTest, StreamOverrides) {
|
TEST_F(GpuStreamUtilTest, StreamOverrides) {
|
||||||
auto root = Scope::NewRootScope().ExitOnError();
|
auto root = Scope::DisabledShapeInferenceScope().ExitOnError();
|
||||||
ops::_Recv(root.WithOpName("input"), DT_FLOAT, "input", "/cpu:0", 0,
|
ops::_Recv(root.WithOpName("input"), DT_FLOAT, "input", "/cpu:0", 0,
|
||||||
"/device:GPU:0");
|
"/device:GPU:0");
|
||||||
Output n = ops::MatMul(root, {}, {});
|
Output n = ops::MatMul(root, {}, {});
|
||||||
|
@ -882,7 +882,7 @@ TEST_F(ShapeRefinerTest, ConstantValueAsShape_Shape) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ShapeRefinerTest, ConstantValueAsShape_PackInt32) {
|
TEST_F(ShapeRefinerTest, ConstantValueAsShape_PackInt32) {
|
||||||
Scope root = Scope::NewRootScope();
|
Scope root = Scope::DisabledShapeInferenceScope();
|
||||||
Node* scalar_non_const;
|
Node* scalar_non_const;
|
||||||
TF_ASSERT_OK(NodeBuilder("in", "NonConstScalarInt32")
|
TF_ASSERT_OK(NodeBuilder("in", "NonConstScalarInt32")
|
||||||
.Finalize(root.graph(), &scalar_non_const));
|
.Finalize(root.graph(), &scalar_non_const));
|
||||||
@ -914,7 +914,7 @@ TEST_F(ShapeRefinerTest, ConstantValueAsShape_PackInt32) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ShapeRefinerTest, ConstantValueAsShape_PackInt64) {
|
TEST_F(ShapeRefinerTest, ConstantValueAsShape_PackInt64) {
|
||||||
Scope root = Scope::NewRootScope();
|
Scope root = Scope::DisabledShapeInferenceScope();
|
||||||
Node* scalar_non_const;
|
Node* scalar_non_const;
|
||||||
TF_ASSERT_OK(NodeBuilder("in", "NonConstScalarInt64")
|
TF_ASSERT_OK(NodeBuilder("in", "NonConstScalarInt64")
|
||||||
.Finalize(root.graph(), &scalar_non_const));
|
.Finalize(root.graph(), &scalar_non_const));
|
||||||
@ -997,7 +997,7 @@ TEST_F(ShapeRefinerTest, ConstantValueAsShape_PackInvalidInput) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ShapeRefinerTest, ConstantValueAsShape_Concat) {
|
TEST_F(ShapeRefinerTest, ConstantValueAsShape_Concat) {
|
||||||
Scope root = Scope::NewRootScope();
|
Scope root = Scope::DisabledShapeInferenceScope();
|
||||||
Graph* g = root.graph();
|
Graph* g = root.graph();
|
||||||
Node* partial_1;
|
Node* partial_1;
|
||||||
Node* partial_2;
|
Node* partial_2;
|
||||||
@ -1034,7 +1034,7 @@ TEST_F(ShapeRefinerTest, ConstantValueAsShape_Concat) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ShapeRefinerTest, ConstantValueAsShape_ConcatWithUnknown) {
|
TEST_F(ShapeRefinerTest, ConstantValueAsShape_ConcatWithUnknown) {
|
||||||
Scope root = Scope::NewRootScope();
|
Scope root = Scope::DisabledShapeInferenceScope();
|
||||||
Graph* g = root.graph();
|
Graph* g = root.graph();
|
||||||
Node* scalar_non_const;
|
Node* scalar_non_const;
|
||||||
TF_ASSERT_OK(NodeBuilder("in", "NonConstScalarInt32")
|
TF_ASSERT_OK(NodeBuilder("in", "NonConstScalarInt32")
|
||||||
@ -1077,7 +1077,7 @@ TEST_F(ShapeRefinerTest, ConstantValueAsShape_ConcatWithUnknown) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ShapeRefinerTest, ConstantValueAsShape_ConcatInvalidDimValue) {
|
TEST_F(ShapeRefinerTest, ConstantValueAsShape_ConcatInvalidDimValue) {
|
||||||
Scope root = Scope::NewRootScope();
|
Scope root = Scope::DisabledShapeInferenceScope();
|
||||||
Graph* g = root.graph();
|
Graph* g = root.graph();
|
||||||
Node* scalar_non_const;
|
Node* scalar_non_const;
|
||||||
TF_ASSERT_OK(NodeBuilder("in", "NonConstScalarInt32")
|
TF_ASSERT_OK(NodeBuilder("in", "NonConstScalarInt32")
|
||||||
|
@ -25,6 +25,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/cc/ops/math_ops.h"
|
#include "tensorflow/cc/ops/math_ops.h"
|
||||||
#include "tensorflow/cc/ops/random_ops.h"
|
#include "tensorflow/cc/ops/random_ops.h"
|
||||||
#include "tensorflow/cc/ops/sendrecv_ops.h"
|
#include "tensorflow/cc/ops/sendrecv_ops.h"
|
||||||
|
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||||
#include "tensorflow/core/framework/function_testlib.h"
|
#include "tensorflow/core/framework/function_testlib.h"
|
||||||
#include "tensorflow/core/framework/op.h"
|
#include "tensorflow/core/framework/op.h"
|
||||||
#include "tensorflow/core/framework/versions.pb.h"
|
#include "tensorflow/core/framework/versions.pb.h"
|
||||||
@ -141,9 +142,17 @@ void CheckLoopConstruction(const GraphDef& graph_def) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
REGISTER_OP("FloatInput").Output("o: float");
|
REGISTER_OP("FloatInput")
|
||||||
REGISTER_OP("BoolInput").Output("o: bool");
|
.Output("o: float")
|
||||||
REGISTER_OP("Combine").Input("a: float").Input("b: float").Output("o: float");
|
.SetShapeFn(shape_inference::UnknownShape);
|
||||||
|
REGISTER_OP("BoolInput")
|
||||||
|
.Output("o: bool")
|
||||||
|
.SetShapeFn(shape_inference::UnknownShape);
|
||||||
|
REGISTER_OP("Combine")
|
||||||
|
.Input("a: float")
|
||||||
|
.Input("b: float")
|
||||||
|
.Output("o: float")
|
||||||
|
.SetShapeFn(shape_inference::UnknownShape);
|
||||||
|
|
||||||
Output ConstructOp(const Scope& scope, const string& op_type,
|
Output ConstructOp(const Scope& scope, const string& op_type,
|
||||||
const gtl::ArraySlice<Input>& inputs) {
|
const gtl::ArraySlice<Input>& inputs) {
|
||||||
@ -158,6 +167,8 @@ Output ConstructOp(const Scope& scope, const string& op_type,
|
|||||||
Node* ret;
|
Node* ret;
|
||||||
scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
|
scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
|
||||||
if (!scope.ok()) return Output();
|
if (!scope.ok()) return Output();
|
||||||
|
scope.UpdateStatus(scope.DoShapeInference(ret));
|
||||||
|
if (!scope.ok()) return Output();
|
||||||
return Output(ret);
|
return Output(ret);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -28,7 +28,7 @@ namespace {
|
|||||||
class AutoParallelTest : public ::testing::Test {};
|
class AutoParallelTest : public ::testing::Test {};
|
||||||
|
|
||||||
TEST_F(AutoParallelTest, SimpleParallel) {
|
TEST_F(AutoParallelTest, SimpleParallel) {
|
||||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
tensorflow::Scope s = tensorflow::Scope::DisabledShapeInferenceScope();
|
||||||
Output constant_a = ops::Const(s.WithOpName("constant_a"), 1.0f, {1});
|
Output constant_a = ops::Const(s.WithOpName("constant_a"), 1.0f, {1});
|
||||||
Output constant_b = ops::Const(s.WithOpName("constant_b"), 1, {1});
|
Output constant_b = ops::Const(s.WithOpName("constant_b"), 1, {1});
|
||||||
Output var = ops::Variable(s.WithOpName("var"), {1}, DT_FLOAT);
|
Output var = ops::Variable(s.WithOpName("var"), {1}, DT_FLOAT);
|
||||||
|
@ -35,7 +35,7 @@ namespace tensorflow {
|
|||||||
using namespace ops; // NOLINT(build/namespaces)
|
using namespace ops; // NOLINT(build/namespaces)
|
||||||
|
|
||||||
TEST(EncodeWavOpTest, EncodeWavTest) {
|
TEST(EncodeWavOpTest, EncodeWavTest) {
|
||||||
Scope root = Scope::NewRootScope();
|
Scope root = Scope::DisabledShapeInferenceScope();
|
||||||
|
|
||||||
Tensor audio_tensor(DT_FLOAT, {4, 2});
|
Tensor audio_tensor(DT_FLOAT, {4, 2});
|
||||||
test::FillValues<float>(
|
test::FillValues<float>(
|
||||||
|
@ -88,7 +88,7 @@ class FuzzSession {
|
|||||||
}
|
}
|
||||||
initialized_ = true;
|
initialized_ = true;
|
||||||
|
|
||||||
Scope root = Scope::NewRootScope().ExitOnError();
|
Scope root = Scope::DisabledShapeInferenceScope().ExitOnError();
|
||||||
SessionOptions options;
|
SessionOptions options;
|
||||||
session_ = std::unique_ptr<Session>(NewSession(options));
|
session_ = std::unique_ptr<Session>(NewSession(options));
|
||||||
|
|
||||||
|
@ -121,7 +121,7 @@ TEST(ImmutableConstantOpTest, ExecutionError) {
|
|||||||
const TensorShape kBadTensorShape({40, 100});
|
const TensorShape kBadTensorShape({40, 100});
|
||||||
const TensorShape kTestTensorShapeT({1, 4});
|
const TensorShape kTestTensorShapeT({1, 4});
|
||||||
|
|
||||||
auto root = Scope::NewRootScope().ExitOnError();
|
auto root = Scope::DisabledShapeInferenceScope().ExitOnError();
|
||||||
auto node1 =
|
auto node1 =
|
||||||
ops::ImmutableConst(root, DT_FLOAT, kBadTensorShape, "test:///2");
|
ops::ImmutableConst(root, DT_FLOAT, kBadTensorShape, "test:///2");
|
||||||
auto node2 =
|
auto node2 =
|
||||||
|
@ -35,7 +35,7 @@ namespace tensorflow {
|
|||||||
using namespace ops; // NOLINT(build/namespaces)
|
using namespace ops; // NOLINT(build/namespaces)
|
||||||
|
|
||||||
TEST(MfccOpTest, SimpleTest) {
|
TEST(MfccOpTest, SimpleTest) {
|
||||||
Scope root = Scope::NewRootScope();
|
Scope root = Scope::DisabledShapeInferenceScope();
|
||||||
|
|
||||||
Tensor spectrogram_tensor(DT_FLOAT, TensorShape({1, 1, 513}));
|
Tensor spectrogram_tensor(DT_FLOAT, TensorShape({1, 1, 513}));
|
||||||
test::FillIota<float>(&spectrogram_tensor, 1.0f);
|
test::FillIota<float>(&spectrogram_tensor, 1.0f);
|
||||||
|
@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||||
#include "tensorflow/core/framework/op.h"
|
#include "tensorflow/core/framework/op.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
@ -26,6 +27,7 @@ REGISTER_OP("_Send")
|
|||||||
.Attr("recv_device: string")
|
.Attr("recv_device: string")
|
||||||
.Attr("client_terminated: bool = false")
|
.Attr("client_terminated: bool = false")
|
||||||
.SetIsStateful()
|
.SetIsStateful()
|
||||||
|
.SetShapeFn(shape_inference::UnknownShape)
|
||||||
.Doc(R"doc(
|
.Doc(R"doc(
|
||||||
Sends the named tensor from send_device to recv_device.
|
Sends the named tensor from send_device to recv_device.
|
||||||
|
|
||||||
@ -49,6 +51,7 @@ REGISTER_OP("_Recv")
|
|||||||
.Attr("recv_device: string")
|
.Attr("recv_device: string")
|
||||||
.Attr("client_terminated: bool = false")
|
.Attr("client_terminated: bool = false")
|
||||||
.SetIsStateful()
|
.SetIsStateful()
|
||||||
|
.SetShapeFn(shape_inference::UnknownShape)
|
||||||
.Doc(R"doc(
|
.Doc(R"doc(
|
||||||
Receives the named tensor from send_device on recv_device.
|
Receives the named tensor from send_device on recv_device.
|
||||||
|
|
||||||
@ -72,6 +75,7 @@ REGISTER_OP("_HostSend")
|
|||||||
.Attr("recv_device: string")
|
.Attr("recv_device: string")
|
||||||
.Attr("client_terminated: bool = false")
|
.Attr("client_terminated: bool = false")
|
||||||
.SetIsStateful()
|
.SetIsStateful()
|
||||||
|
.SetShapeFn(shape_inference::UnknownShape)
|
||||||
.Doc(R"doc(
|
.Doc(R"doc(
|
||||||
Sends the named tensor from send_device to recv_device.
|
Sends the named tensor from send_device to recv_device.
|
||||||
|
|
||||||
@ -98,6 +102,7 @@ REGISTER_OP("_HostRecv")
|
|||||||
.Attr("recv_device: string")
|
.Attr("recv_device: string")
|
||||||
.Attr("client_terminated: bool = false")
|
.Attr("client_terminated: bool = false")
|
||||||
.SetIsStateful()
|
.SetIsStateful()
|
||||||
|
.SetShapeFn(shape_inference::UnknownShape)
|
||||||
.Doc(R"doc(
|
.Doc(R"doc(
|
||||||
Receives the named tensor from send_device on recv_device.
|
Receives the named tensor from send_device on recv_device.
|
||||||
|
|
||||||
|
@ -36,7 +36,7 @@ class FakeQuantizeTrainingTest : public ::testing::Test {};
|
|||||||
// TODO(suharshs): Once we implement the fake_quantize_training transform
|
// TODO(suharshs): Once we implement the fake_quantize_training transform
|
||||||
// using the GTT, write proper tests of the transform here.
|
// using the GTT, write proper tests of the transform here.
|
||||||
TEST_F(FakeQuantizeTrainingTest, TransformOccurred) {
|
TEST_F(FakeQuantizeTrainingTest, TransformOccurred) {
|
||||||
auto root = tensorflow::Scope::NewRootScope();
|
auto root = tensorflow::Scope::DisabledShapeInferenceScope();
|
||||||
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
|
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
|
||||||
|
|
||||||
Tensor a_data(DT_FLOAT, TensorShape());
|
Tensor a_data(DT_FLOAT, TensorShape());
|
||||||
|
@ -40,7 +40,7 @@ class QuantizeWeightsTest : public ::testing::Test {
|
|||||||
const TensorShape& weight_shape,
|
const TensorShape& weight_shape,
|
||||||
std::initializer_list<float> weight_values,
|
std::initializer_list<float> weight_values,
|
||||||
GraphDef* original_graph_def) {
|
GraphDef* original_graph_def) {
|
||||||
auto root = tensorflow::Scope::NewRootScope();
|
auto root = tensorflow::Scope::DisabledShapeInferenceScope();
|
||||||
|
|
||||||
Tensor input_data(DT_FLOAT, input_shape);
|
Tensor input_data(DT_FLOAT, input_shape);
|
||||||
test::FillValues<float>(&input_data, input_values);
|
test::FillValues<float>(&input_data, input_values);
|
||||||
|
@ -622,7 +622,7 @@ class TransformUtilsTest : public ::testing::Test {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void TestRenameNodeInputsWithWildcard() {
|
void TestRenameNodeInputsWithWildcard() {
|
||||||
auto root = tensorflow::Scope::NewRootScope();
|
auto root = tensorflow::Scope::DisabledShapeInferenceScope();
|
||||||
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
|
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
|
||||||
|
|
||||||
const int width = 10;
|
const int width = 10;
|
||||||
|
Loading…
x
Reference in New Issue
Block a user