C++ API: Added a Const constructor for non-empty const supporting type cast.

Fixes #3752
Change: 130113000
This commit is contained in:
Manjunath Kudlur 2016-08-12 09:19:52 -08:00 committed by Gunhan Gulsoy
parent 67734a1df6
commit 5ce2567953
3 changed files with 78 additions and 13 deletions

View File

@ -226,4 +226,49 @@ TEST(CCOpTest, ColocateWith) {
EXPECT_TRUE(attrs.find("_class") == attrs.end());
}
TEST(CCOpTest, TemplatedConst) {
Scope root = Scope::NewRootScope();
auto c1 = ops::Const<float>(root, {{3, 2}, {-1, 0}});
TF_EXPECT_OK(root.status());
Tensor out;
GetTensor(root, c1, &out);
test::ExpectTensorEqual<float>(
out, test::AsTensor<float>({3.f, 2.f, -1.f, 0.f}, {2, 2}));
auto c2 = ops::Const<string>(root, {{"this"}, {"is"}, {"a"}, {"constant"}});
GetTensor(root, c2, &out);
test::ExpectTensorEqual<string>(
out, test::AsTensor<string>({"this", "is", "a", "constant"}, {4, 1}));
}
TEST(CCOpTest, EmptyConst) {
Scope root = Scope::NewRootScope();
auto c1 = ops::Const(root, {});
TF_CHECK_OK(root.status());
Tensor out;
GetTensor(root, c1, &out);
test::ExpectTensorEqual<float>(out, Tensor(DT_FLOAT, {0}));
auto c2 = ops::Const(root, {{}});
TF_CHECK_OK(root.status());
GetTensor(root, c2, &out);
test::ExpectTensorEqual<float>(out, Tensor(DT_FLOAT, {1, 0}));
auto c3 = ops::Const(root, {{{}, {}}});
TF_CHECK_OK(root.status());
GetTensor(root, c3, &out);
test::ExpectTensorEqual<float>(out, Tensor(DT_FLOAT, {1, 2, 0}));
auto c4 = ops::Const<int>(root, {{{}}});
TF_CHECK_OK(root.status());
GetTensor(root, c4, &out);
test::ExpectTensorEqual<int>(out, Tensor(DT_INT32, {1, 1, 0}));
ops::Const(root, {{}, {{}}});
EXPECT_FALSE(root.status().ok());
}
} // namespace tensorflow

View File

@ -25,22 +25,35 @@ namespace ops {
Output Const(const Scope& scope, const Input::Initializer& val);
NodeBuilder::NodeOut AsNodeOut(const Scope& scope, const Input& inp);
template <typename T>
Output Const(const Scope& scope, const Input::Initializer& val) {
auto orig_const_output = Const(scope, val);
if (!scope.ok()) return Output();
if (!val.status.ok()) {
scope.UpdateStatus(val.status);
return Output();
}
typedef typename Input::Initializer::RealType<T>::type DstT;
if (val.tensor.NumElements() > 0) {
// TODO(keveman): Implement the in-situ cast.
scope.UpdateStatus(errors::Unimplemented(
"Explict cast of a non-empty tensor not implemented yet"));
return Output();
if (val.tensor.dtype() == DataTypeToEnum<DstT>::v()) {
return orig_const_output;
}
Tensor t(DataTypeToEnum<DstT>::v(), val.tensor.shape());
return Const(scope, Input::Initializer(t));
if (val.tensor.NumElements() == 0) {
Tensor t(DataTypeToEnum<DstT>::v(), val.tensor.shape());
return Const(scope, Input::Initializer(t));
}
// TODO(keveman): Refactor Cast op's kernel implementation such that the code
// can be directly called here instead of adding the Cast op to the graph.
auto orig_const = AsNodeOut(scope, orig_const_output);
const auto cast_op_name = scope.GetUniqueNameForOp("Cast");
auto cast_builder = NodeBuilder(cast_op_name, "Cast")
.Input(orig_const)
.Attr("DstT", DataTypeToEnum<DstT>::v());
scope.UpdateBuilder(&cast_builder);
Node* ret;
scope.UpdateStatus(cast_builder.Finalize(scope.graph(), &ret));
return Output(ret, 0);
}
template <typename T>
@ -54,8 +67,6 @@ Output Const(const Scope& scope, const std::initializer_list<T>& v,
return Const(scope, Input::Initializer(v, shape));
}
NodeBuilder::NodeOut AsNodeOut(const Scope& scope, const Input& inp);
std::vector<NodeBuilder::NodeOut> AsNodeOutList(const Scope& scope,
const InputList& inp);

View File

@ -125,4 +125,13 @@ TEST(ConstOpTest, Names) {
EXPECT_EQ(c_y_1.node()->name(), "c/y_1");
}
TEST(ConstOpTest, TemplatedConst) {
Scope root = Scope::NewRootScope();
auto c1 = ops::Const<int>(root, {1, 2});
ExpectTypeAndShape(c1.node(), DT_INT32, {2});
auto c2 = ops::Const<string>(root, {{"this"}, {"is"}, {"a"}, {"constant"}});
ExpectTypeAndShape(c2.node(), DT_STRING, {4, 1});
}
} // namespace tensorflow