C++ API: Added a Const constructor for non-empty const supporting type cast.
Fixes #3752 Change: 130113000
This commit is contained in:
parent
67734a1df6
commit
5ce2567953
@ -226,4 +226,49 @@ TEST(CCOpTest, ColocateWith) {
|
||||
EXPECT_TRUE(attrs.find("_class") == attrs.end());
|
||||
}
|
||||
|
||||
TEST(CCOpTest, TemplatedConst) {
|
||||
Scope root = Scope::NewRootScope();
|
||||
auto c1 = ops::Const<float>(root, {{3, 2}, {-1, 0}});
|
||||
TF_EXPECT_OK(root.status());
|
||||
|
||||
Tensor out;
|
||||
GetTensor(root, c1, &out);
|
||||
test::ExpectTensorEqual<float>(
|
||||
out, test::AsTensor<float>({3.f, 2.f, -1.f, 0.f}, {2, 2}));
|
||||
|
||||
auto c2 = ops::Const<string>(root, {{"this"}, {"is"}, {"a"}, {"constant"}});
|
||||
GetTensor(root, c2, &out);
|
||||
test::ExpectTensorEqual<string>(
|
||||
out, test::AsTensor<string>({"this", "is", "a", "constant"}, {4, 1}));
|
||||
}
|
||||
|
||||
TEST(CCOpTest, EmptyConst) {
|
||||
Scope root = Scope::NewRootScope();
|
||||
|
||||
auto c1 = ops::Const(root, {});
|
||||
TF_CHECK_OK(root.status());
|
||||
|
||||
Tensor out;
|
||||
GetTensor(root, c1, &out);
|
||||
test::ExpectTensorEqual<float>(out, Tensor(DT_FLOAT, {0}));
|
||||
|
||||
auto c2 = ops::Const(root, {{}});
|
||||
TF_CHECK_OK(root.status());
|
||||
GetTensor(root, c2, &out);
|
||||
test::ExpectTensorEqual<float>(out, Tensor(DT_FLOAT, {1, 0}));
|
||||
|
||||
auto c3 = ops::Const(root, {{{}, {}}});
|
||||
TF_CHECK_OK(root.status());
|
||||
GetTensor(root, c3, &out);
|
||||
test::ExpectTensorEqual<float>(out, Tensor(DT_FLOAT, {1, 2, 0}));
|
||||
|
||||
auto c4 = ops::Const<int>(root, {{{}}});
|
||||
TF_CHECK_OK(root.status());
|
||||
GetTensor(root, c4, &out);
|
||||
test::ExpectTensorEqual<int>(out, Tensor(DT_INT32, {1, 1, 0}));
|
||||
|
||||
ops::Const(root, {{}, {{}}});
|
||||
EXPECT_FALSE(root.status().ok());
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -25,22 +25,35 @@ namespace ops {
|
||||
|
||||
Output Const(const Scope& scope, const Input::Initializer& val);
|
||||
|
||||
NodeBuilder::NodeOut AsNodeOut(const Scope& scope, const Input& inp);
|
||||
|
||||
template <typename T>
|
||||
Output Const(const Scope& scope, const Input::Initializer& val) {
|
||||
auto orig_const_output = Const(scope, val);
|
||||
if (!scope.ok()) return Output();
|
||||
if (!val.status.ok()) {
|
||||
scope.UpdateStatus(val.status);
|
||||
return Output();
|
||||
}
|
||||
|
||||
typedef typename Input::Initializer::RealType<T>::type DstT;
|
||||
if (val.tensor.NumElements() > 0) {
|
||||
// TODO(keveman): Implement the in-situ cast.
|
||||
scope.UpdateStatus(errors::Unimplemented(
|
||||
"Explict cast of a non-empty tensor not implemented yet"));
|
||||
return Output();
|
||||
|
||||
if (val.tensor.dtype() == DataTypeToEnum<DstT>::v()) {
|
||||
return orig_const_output;
|
||||
}
|
||||
Tensor t(DataTypeToEnum<DstT>::v(), val.tensor.shape());
|
||||
return Const(scope, Input::Initializer(t));
|
||||
if (val.tensor.NumElements() == 0) {
|
||||
Tensor t(DataTypeToEnum<DstT>::v(), val.tensor.shape());
|
||||
return Const(scope, Input::Initializer(t));
|
||||
}
|
||||
|
||||
// TODO(keveman): Refactor Cast op's kernel implementation such that the code
|
||||
// can be directly called here instead of adding the Cast op to the graph.
|
||||
auto orig_const = AsNodeOut(scope, orig_const_output);
|
||||
const auto cast_op_name = scope.GetUniqueNameForOp("Cast");
|
||||
|
||||
auto cast_builder = NodeBuilder(cast_op_name, "Cast")
|
||||
.Input(orig_const)
|
||||
.Attr("DstT", DataTypeToEnum<DstT>::v());
|
||||
scope.UpdateBuilder(&cast_builder);
|
||||
Node* ret;
|
||||
scope.UpdateStatus(cast_builder.Finalize(scope.graph(), &ret));
|
||||
return Output(ret, 0);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
@ -54,8 +67,6 @@ Output Const(const Scope& scope, const std::initializer_list<T>& v,
|
||||
return Const(scope, Input::Initializer(v, shape));
|
||||
}
|
||||
|
||||
NodeBuilder::NodeOut AsNodeOut(const Scope& scope, const Input& inp);
|
||||
|
||||
std::vector<NodeBuilder::NodeOut> AsNodeOutList(const Scope& scope,
|
||||
const InputList& inp);
|
||||
|
||||
|
@ -125,4 +125,13 @@ TEST(ConstOpTest, Names) {
|
||||
EXPECT_EQ(c_y_1.node()->name(), "c/y_1");
|
||||
}
|
||||
|
||||
TEST(ConstOpTest, TemplatedConst) {
|
||||
Scope root = Scope::NewRootScope();
|
||||
auto c1 = ops::Const<int>(root, {1, 2});
|
||||
ExpectTypeAndShape(c1.node(), DT_INT32, {2});
|
||||
|
||||
auto c2 = ops::Const<string>(root, {{"this"}, {"is"}, {"a"}, {"constant"}});
|
||||
ExpectTypeAndShape(c2.node(), DT_STRING, {4, 1});
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
Loading…
Reference in New Issue
Block a user