From 5ce2567953b3bde1665e82a75bebf35066111b2c Mon Sep 17 00:00:00 2001 From: Manjunath Kudlur Date: Fri, 12 Aug 2016 09:19:52 -0800 Subject: [PATCH] C++ API: Added a Const constructor for non-empty const supporting type cast. Fixes #3752 Change: 130113000 --- tensorflow/cc/framework/cc_ops_test.cc | 45 ++++++++++++++++++++++++++ tensorflow/cc/ops/const_op.h | 37 +++++++++++++-------- tensorflow/cc/ops/const_op_test.cc | 9 ++++++ 3 files changed, 78 insertions(+), 13 deletions(-) diff --git a/tensorflow/cc/framework/cc_ops_test.cc b/tensorflow/cc/framework/cc_ops_test.cc index e1bc666de69..cecc633ca1d 100644 --- a/tensorflow/cc/framework/cc_ops_test.cc +++ b/tensorflow/cc/framework/cc_ops_test.cc @@ -226,4 +226,49 @@ TEST(CCOpTest, ColocateWith) { EXPECT_TRUE(attrs.find("_class") == attrs.end()); } +TEST(CCOpTest, TemplatedConst) { + Scope root = Scope::NewRootScope(); + auto c1 = ops::Const(root, {{3, 2}, {-1, 0}}); + TF_EXPECT_OK(root.status()); + + Tensor out; + GetTensor(root, c1, &out); + test::ExpectTensorEqual( + out, test::AsTensor({3.f, 2.f, -1.f, 0.f}, {2, 2})); + + auto c2 = ops::Const(root, {{"this"}, {"is"}, {"a"}, {"constant"}}); + GetTensor(root, c2, &out); + test::ExpectTensorEqual( + out, test::AsTensor({"this", "is", "a", "constant"}, {4, 1})); +} + +TEST(CCOpTest, EmptyConst) { + Scope root = Scope::NewRootScope(); + + auto c1 = ops::Const(root, {}); + TF_CHECK_OK(root.status()); + + Tensor out; + GetTensor(root, c1, &out); + test::ExpectTensorEqual(out, Tensor(DT_FLOAT, {0})); + + auto c2 = ops::Const(root, {{}}); + TF_CHECK_OK(root.status()); + GetTensor(root, c2, &out); + test::ExpectTensorEqual(out, Tensor(DT_FLOAT, {1, 0})); + + auto c3 = ops::Const(root, {{{}, {}}}); + TF_CHECK_OK(root.status()); + GetTensor(root, c3, &out); + test::ExpectTensorEqual(out, Tensor(DT_FLOAT, {1, 2, 0})); + + auto c4 = ops::Const(root, {{{}}}); + TF_CHECK_OK(root.status()); + GetTensor(root, c4, &out); + test::ExpectTensorEqual(out, Tensor(DT_INT32, {1, 1, 0})); + + ops::Const(root, {{}, {{}}}); + EXPECT_FALSE(root.status().ok()); +} + } // namespace tensorflow diff --git a/tensorflow/cc/ops/const_op.h b/tensorflow/cc/ops/const_op.h index 75844d124d9..8976a24edc6 100644 --- a/tensorflow/cc/ops/const_op.h +++ b/tensorflow/cc/ops/const_op.h @@ -25,22 +25,35 @@ namespace ops { Output Const(const Scope& scope, const Input::Initializer& val); +NodeBuilder::NodeOut AsNodeOut(const Scope& scope, const Input& inp); + template Output Const(const Scope& scope, const Input::Initializer& val) { + auto orig_const_output = Const(scope, val); if (!scope.ok()) return Output(); - if (!val.status.ok()) { - scope.UpdateStatus(val.status); - return Output(); - } + typedef typename Input::Initializer::RealType::type DstT; - if (val.tensor.NumElements() > 0) { - // TODO(keveman): Implement the in-situ cast. - scope.UpdateStatus(errors::Unimplemented( - "Explict cast of a non-empty tensor not implemented yet")); - return Output(); + + if (val.tensor.dtype() == DataTypeToEnum::v()) { + return orig_const_output; } - Tensor t(DataTypeToEnum::v(), val.tensor.shape()); - return Const(scope, Input::Initializer(t)); + if (val.tensor.NumElements() == 0) { + Tensor t(DataTypeToEnum::v(), val.tensor.shape()); + return Const(scope, Input::Initializer(t)); + } + + // TODO(keveman): Refactor Cast op's kernel implementation such that the code + // can be directly called here instead of adding the Cast op to the graph. + auto orig_const = AsNodeOut(scope, orig_const_output); + const auto cast_op_name = scope.GetUniqueNameForOp("Cast"); + + auto cast_builder = NodeBuilder(cast_op_name, "Cast") + .Input(orig_const) + .Attr("DstT", DataTypeToEnum::v()); + scope.UpdateBuilder(&cast_builder); + Node* ret; + scope.UpdateStatus(cast_builder.Finalize(scope.graph(), &ret)); + return Output(ret, 0); } template @@ -54,8 +67,6 @@ Output Const(const Scope& scope, const std::initializer_list& v, return Const(scope, Input::Initializer(v, shape)); } -NodeBuilder::NodeOut AsNodeOut(const Scope& scope, const Input& inp); - std::vector AsNodeOutList(const Scope& scope, const InputList& inp); diff --git a/tensorflow/cc/ops/const_op_test.cc b/tensorflow/cc/ops/const_op_test.cc index a56b66c1ccd..5a4770f879f 100644 --- a/tensorflow/cc/ops/const_op_test.cc +++ b/tensorflow/cc/ops/const_op_test.cc @@ -125,4 +125,13 @@ TEST(ConstOpTest, Names) { EXPECT_EQ(c_y_1.node()->name(), "c/y_1"); } +TEST(ConstOpTest, TemplatedConst) { + Scope root = Scope::NewRootScope(); + auto c1 = ops::Const(root, {1, 2}); + ExpectTypeAndShape(c1.node(), DT_INT32, {2}); + + auto c2 = ops::Const(root, {{"this"}, {"is"}, {"a"}, {"constant"}}); + ExpectTypeAndShape(c2.node(), DT_STRING, {4, 1}); +} + } // namespace tensorflow