Support fp16 conversion for the Square, Relu and LeakyRelu converters.

This also fixes a bug in the GatherV2 converter where it used the wrong data
type attribute name. fp16 support will be enabled for more converters later.
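For context on the GatherV2 bug: the type check read the attribute "T", but
GatherV2's op definition stores its element type under "Tparams" (the indices
and axis types live under "Tindices" and "Taxis"; there is no "T" at all). A
condensed sketch of the corrected check, mirroring the ConvertGather hunk
below; the helper name CheckGatherDtype is hypothetical:

    // Hypothetical helper: GatherV2 must be type-checked against "Tparams",
    // not the default "T", which GatherV2 does not define.
    Status CheckGatherDtype(OpConverterParams* params) {
      return AllowDataTypes(
          *params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32},
          /*dtype_attr_name=*/"Tparams");
    }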

PiperOrigin-RevId: 238485945
Author: Guangda Lai, 2019-03-14 11:40:11 -07:00 (committed by TensorFlower Gardener)
parent 7239d65880
commit 762093140b
3 changed files with 187 additions and 128 deletions

tensorflow/compiler/tf2tensorrt/BUILD

@@ -326,6 +326,7 @@ tf_cuda_cc_test(
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:span",
         "//tensorflow/cc:cc_ops",
+        "//tensorflow/cc:cc_ops_internal",
         "//tensorflow/cc:ops",
         "//tensorflow/cc:scope",
         "//tensorflow/core/grappler/costs:graph_properties",

tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc

@@ -116,6 +116,88 @@ inline Status ConvertDType(DataType tf_dtype, nvinfer1::DataType* trt_dtype) {
   return Status::OK();
 }
 
+class TFAttrs {
+ public:
+  explicit TFAttrs(const NodeDef& tf_node) {
+    for (const auto& attr : tf_node.attr()) {
+      attrs_.insert({attr.first, &attr.second});
+    }
+  }
+
+  bool count(const string& key) const { return attrs_.count(key); }
+
+  AttrValue const* at(const string& key) const {
+    if (!attrs_.count(key)) {
+      LOG(FATAL) << "Attribute not found: " << key;
+    }
+    return attrs_.at(key);
+  }
+
+  template <typename T>
+  T get(const string& key) const;
+
+  template <typename T>
+  T get(const string& key, const T& default_value) const {
+    return attrs_.count(key) ? this->get<T>(key) : default_value;
+  }
+
+  std::vector<string> GetAllAttrKeys() const {
+    std::vector<string> attr_list;
+    for (const auto& attr_item : attrs_) {
+      attr_list.emplace_back(attr_item.first);
+    }
+    return attr_list;
+  }
+
+ private:
+  typedef std::map<string, AttrValue const*> AttrMap;
+  AttrMap attrs_;
+};
+
+template <>
+string TFAttrs::get<string>(const string& key) const {
+  return this->at(key)->s();
+}
+
+template <>
+std::vector<int64> TFAttrs::get<std::vector<int64>>(const string& key) const {
+  auto attr = this->at(key)->list().i();
+  return std::vector<int64>(attr.begin(), attr.end());
+}
+
+template <>
+std::vector<float> TFAttrs::get<std::vector<float>>(const string& key) const {
+  auto attr = this->at(key)->list().f();
+  return std::vector<float>(attr.begin(), attr.end());
+}
+
+template <>
+nvinfer1::DataType TFAttrs::get<nvinfer1::DataType>(const string& key) const {
+  nvinfer1::DataType trt_dtype(nvinfer1::DataType::kFLOAT);
+  TF_CHECK_OK(ConvertDType(this->at(key)->type(), &trt_dtype));
+  return trt_dtype;
+}
+
+template <>
+DataType TFAttrs::get<DataType>(const string& key) const {
+  return this->at(key)->type();
+}
+
+template <>
+float TFAttrs::get<float>(const string& key) const {
+  return this->at(key)->f();
+}
+
+template <>
+bool TFAttrs::get<bool>(const string& key) const {
+  return this->at(key)->b();
+}
+
+template <>
+int64 TFAttrs::get<int64>(const string& key) const {
+  return this->at(key)->i();
+}
+
 template <typename TensorShapeType>
 inline nvinfer1::Dims TensorShapeToTrtDims(const TensorShapeType& shape,
                                            bool ignore_first_dim) {
@@ -379,17 +461,35 @@ nvinfer1::ITensor* Converter::CreateConstantLayer(
 Status CreateBroadcastableScalarConstant(OpConverterParams* params, float value,
                                          const nvinfer1::Dims& dims,
-                                         const nvinfer1::ITensor** tensor) {
+                                         const nvinfer1::ITensor** tensor,
+                                         const char* dtype_attr_name = "T") {
+  TFAttrs attrs(params->node_def);
+  DataType dtype;
+  if (attrs.count(dtype_attr_name)) {
+    dtype = attrs.get<DataType>(dtype_attr_name);
+  } else {
+    dtype = DT_FLOAT;  // Default to FP32.
+  }
   // In order to be broadcastable, the number of dims has to match.
   nvinfer1::Dims broadcastable_dims(dims);
   for (int i = 0; i < broadcastable_dims.nbDims; i++) {
     broadcastable_dims.d[i] = 1;
   }
-  TRT_ShapedWeights weights = params->weight_store->GetTempWeights(
-      DataType::DT_FLOAT, broadcastable_dims);
-  auto weights_ptr =
-      static_cast<float*>(const_cast<void*>(weights.GetValues()));
-  weights_ptr[0] = value;
+  TRT_ShapedWeights weights =
+      params->weight_store->GetTempWeights(dtype, broadcastable_dims);
+  void* raw_ptr = const_cast<void*>(weights.GetValues());
+  switch (dtype) {
+    case DataType::DT_FLOAT:
+      static_cast<float*>(raw_ptr)[0] = value;
+      break;
+    case DataType::DT_HALF:
+      static_cast<Eigen::half*>(raw_ptr)[0] = Eigen::half(value);
+      break;
+    default:
+      return errors::InvalidArgument("Unsupported data type ",
+                                     DataTypeString(dtype));
+  }
   *tensor = params->converter->CreateConstantLayer(weights, broadcastable_dims);
   TFTRT_RETURN_ERROR_IF_NULLPTR(*tensor, params->node_def.name());
   params->converter->ProvideQuantizationRange(
@@ -662,88 +762,6 @@ string TRT_TensorOrWeights::DebugString() const {
   return output;
 }
 
-class TFAttrs {
- public:
-  explicit TFAttrs(const NodeDef& tf_node) {
-    for (const auto& attr : tf_node.attr()) {
-      attrs_.insert({attr.first, &attr.second});
-    }
-  }
-
-  bool count(const string& key) const { return attrs_.count(key); }
-
-  AttrValue const* at(const string& key) const {
-    if (!attrs_.count(key)) {
-      LOG(FATAL) << "Attribute not found: " << key;
-    }
-    return attrs_.at(key);
-  }
-
-  template <typename T>
-  T get(const string& key) const;
-
-  template <typename T>
-  T get(const string& key, const T& default_value) const {
-    return attrs_.count(key) ? this->get<T>(key) : default_value;
-  }
-
-  std::vector<string> GetAllAttrKeys() const {
-    std::vector<string> attr_list;
-    for (const auto& attr_item : attrs_) {
-      attr_list.emplace_back(attr_item.first);
-    }
-    return attr_list;
-  }
-
- private:
-  typedef std::map<string, AttrValue const*> AttrMap;
-  AttrMap attrs_;
-};
-
-template <>
-string TFAttrs::get<string>(const string& key) const {
-  return this->at(key)->s();
-}
-
-template <>
-std::vector<int64> TFAttrs::get<std::vector<int64>>(const string& key) const {
-  auto attr = this->at(key)->list().i();
-  return std::vector<int64>(attr.begin(), attr.end());
-}
-
-template <>
-std::vector<float> TFAttrs::get<std::vector<float>>(const string& key) const {
-  auto attr = this->at(key)->list().f();
-  return std::vector<float>(attr.begin(), attr.end());
-}
-
-template <>
-nvinfer1::DataType TFAttrs::get<nvinfer1::DataType>(const string& key) const {
-  nvinfer1::DataType trt_dtype(nvinfer1::DataType::kFLOAT);
-  TF_CHECK_OK(ConvertDType(this->at(key)->type(), &trt_dtype));
-  return trt_dtype;
-}
-
-template <>
-DataType TFAttrs::get<DataType>(const string& key) const {
-  return this->at(key)->type();
-}
-
-template <>
-float TFAttrs::get<float>(const string& key) const {
-  return this->at(key)->f();
-}
-
-template <>
-bool TFAttrs::get<bool>(const string& key) const {
-  return this->at(key)->b();
-}
-
-template <>
-int64 TFAttrs::get<int64>(const string& key) const {
-  return this->at(key)->i();
-}
-
 // TODO(jie): reorder4 & reorder2 should be merged?
 // TODO(aaroey): fix the order of parameters.
 template <typename T>
@@ -1435,11 +1453,15 @@ Status CheckInputsWeights(
 }
 
 Status AllowDataTypes(const OpConverterParams& params,
-                      const std::set<DataType>& allowed_dtypes) {
+                      const std::set<DataType>& allowed_dtypes,
+                      const char* dtype_attr_name = "T") {
   const auto& node_def = params.node_def;
-  TFAttrs attrs(params.node_def);
-  if (attrs.count("T")) {
-    const auto op_dtype = attrs.get<DataType>("T");
+  TFAttrs attrs(node_def);
+  if (!attrs.count(dtype_attr_name)) {
+    return errors::InvalidArgument("Attribute with name ", dtype_attr_name,
+                                   " not found.");
+  }
+  const auto op_dtype = attrs.get<DataType>(dtype_attr_name);
   if (!allowed_dtypes.count(op_dtype)) {
     // Build string list of allowed types.
     std::ostringstream ss;
@@ -1452,9 +1474,6 @@ Status AllowDataTypes(const OpConverterParams& params,
                                    ", must be one of [", ss.str(), "], at ",
                                    node_def.name());
   }
-  }
-  // If there is no T attribute, we can't determine the type of the op. We will
-  // allow it to convert for now.
   return Status::OK();
 }
@@ -3696,7 +3715,8 @@ Status ConvertGather(OpConverterParams* params) {
   TF_RETURN_IF_ERROR(CheckInputsWeights(
       *params, {{"params", false}, {"indices", false}, {"axis", true}}));
   TF_RETURN_IF_ERROR(AllowDataTypes(
-      *params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32}));
+      *params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32},
+      /*dtype_attr_name=*/"Tparams"));
   absl::Span<const int> axis = inputs.at(2).weights().GetSpan<int>();
   if (axis.size() != 1) {
     return errors::InvalidArgument("Axis for GatherV2 must be a scalar, at ",

tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc

@@ -28,6 +28,7 @@ limitations under the License.
 #include "absl/types/span.h"
 #include "tensorflow/cc/framework/ops.h"
 #include "tensorflow/cc/framework/scope.h"
+#include "tensorflow/cc/ops/nn_ops_internal.h"
 #include "tensorflow/cc/ops/standard_ops.h"
 #include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h"
 #include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
@@ -109,13 +110,17 @@ DataType TrtDataTypeToTf(nvinfer1::DataType trt_dtype) {
 }
 
 NodeDef MakeNodeDef(const string& name, const string& op,
-                    const std::vector<string>& inputs) {
+                    const std::vector<string>& inputs,
+                    const std::map<string, AttrValue> attrs = {}) {
   NodeDef node_def;
   node_def.set_name(name);
   node_def.set_op(op);
   for (const string& input : inputs) {
     node_def.add_input(input);
   }
+  for (const auto& attr : attrs) {
+    (*node_def.mutable_attr())[attr.first] = attr.second;
+  }
   return node_def;
 }
@@ -1094,8 +1099,22 @@ class OpConverterTest : public ::testing::Test {
     validator_inputs_.clear();
   }
 
-  // TODO(laigd): test fp16 and int8 support.
-  void BuildAndRun(const DataVec& input_data, DataVec* output_data) {
+  void CheckDataTypeMatches(const DataVec& datas) {
+    for (const auto& data : datas) {
+      const int input_index = engine_->getBindingIndex(data.name);
+      ASSERT_NE(-1, input_index);
+      const nvinfer1::DataType trt_dtype =
+          engine_->getBindingDataType(input_index);
+      const DataType tf_dtype = TrtDataTypeToTf(trt_dtype);
+      ASSERT_EQ(data.tensor.dtype(), tf_dtype)
+          << DataTypeString(data.tensor.dtype()) << " vs. "
+          << DataTypeString(tf_dtype);
+    }
+  }
+
+  // TODO(laigd): test fp16 and int8 support for more converters.
+  void BuildAndRun(const DataVec& input_data, DataVec* output_data,
+                   TrtPrecisionMode precision_mode = TrtPrecisionMode::FP32) {
     // Mark the output tensor as TRT engine output.
     std::vector<Converter::EngineOutputInfo> output_info;
     for (const auto& data : *output_data) {
@@ -1105,9 +1124,20 @@ class OpConverterTest : public ::testing::Test {
     TF_EXPECT_OK(converter_->RenameAndMarkOutputTensors(output_info));
 
     // Build the TRT engine.
+    if (precision_mode == TrtPrecisionMode::FP16) {
+      builder_->setFp16Mode(true);
+    } else if (precision_mode == TrtPrecisionMode::INT8) {
+      // Setting FP16 mode as well allows TRT to also consider FP16 kernels and
+      // use them in situations where they are faster than INT8 or where INT8
+      // is not supported for a given layer.
+      builder_->setFp16Mode(true);
+      builder_->setInt8Mode(true);
+    }
     ASSERT_EQ(nullptr, engine_.get());
     engine_.reset(builder_->buildCudaEngine(*converter_->network()));
     CHECK_NOTNULL(engine_.get());
+    CheckDataTypeMatches(input_data);
+    CheckDataTypeMatches(*output_data);
 
     // Execute the TRT engine.
     const int num_bindings = input_data.size() + output_data->size();
@@ -1761,7 +1791,9 @@ void TestBinaryTensorOpWeightNoBroadcast(OpConverterTest* test) {
   const DataVec input_data{
       {"input", test::AsTensor<CType>(swap_inputs ? operand2 : operand1)}};
   DataVec output_data{{"my_binary", ConstructTensor<CType>(2)}};
-  test->BuildAndRun(input_data, &output_data);
+  test->BuildAndRun(
+      input_data, &output_data,
+      dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32);
   if (node_def.op() == "Add") {
     EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
                 ElementsAre(CType(5), CType(10.5)));
@@ -1942,7 +1974,9 @@ void TestBinaryTensorOpTensor(OpConverterTest* test) {
   DataVec output_data{{"my_binary", ConstructTensor<CType>(4)}};
   // After broadcasting first input becomes {3, 6, 3, 6} and second input
   // becomes {2, 3, 2, 3}.
-  test->BuildAndRun(input_data, &output_data);
+  test->BuildAndRun(
+      input_data, &output_data,
+      dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32);
   if (node_def.op() == "Add") {
     EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
                 ElementsAre(CType(5), CType(8), CType(6), CType(9)));
@@ -1974,10 +2008,13 @@ void TestBinaryTensorOpTensor(OpConverterTest* test) {
 }
 
 TEST_F(OpConverterTest, ConvertBinary) {
+  AttrValue dtype;
+  dtype.set_type(DT_FLOAT);
   // Input size doesn't match, should fail.
   for (size_t num_inputs = 0; num_inputs < 2; ++num_inputs) {
     Reset();
-    NodeDef node_def = MakeNodeDef("my_add", "Add", {num_inputs, "input"});
+    NodeDef node_def =
+        MakeNodeDef("my_add", "Add", {num_inputs, "input"}, {{"T", dtype}});
     AddTestTensor("input", {1}, /*batch_size=*/1, nvinfer1::DataType::kFLOAT);
     RunValidationAndConversion(node_def, error::INVALID_ARGUMENT,
                                StrCat("Add got ", std::to_string(num_inputs),
@@ -1987,7 +2024,8 @@ TEST_F(OpConverterTest, ConvertBinary) {
   {
     // Both inputs are weights.
     Reset();
-    NodeDef node_def = MakeNodeDef("my_add", "Add", {"weights1", "weights2"});
+    NodeDef node_def =
+        MakeNodeDef("my_add", "Add", {"weights1", "weights2"}, {{"T", dtype}});
    AddTestWeights<float>("weights1", {1}, {1});
    AddTestWeights<float>("weights2", {1}, {1});
    RunValidationAndConversion(
@@ -2002,15 +2040,12 @@ TEST_F(OpConverterTest, ConvertBinary) {
   TestBinaryTensorOpWeightNoBroadcast<ops::Mul, DT_FLOAT>(this);
   TestBinaryTensorOpWeightNoBroadcast<ops::Div, DT_FLOAT>(this);
   TestBinaryTensorOpWeightNoBroadcast<ops::RealDiv, DT_FLOAT>(this);
-#if 0
-  // TODO(b/119560144): it doesn't support FP16 constants and the following
-  // test will fail.
   TestBinaryTensorOpWeightNoBroadcast<ops::Add, DT_HALF>(this);
   TestBinaryTensorOpWeightNoBroadcast<ops::Sub, DT_HALF>(this);
   TestBinaryTensorOpWeightNoBroadcast<ops::Mul, DT_HALF>(this);
   TestBinaryTensorOpWeightNoBroadcast<ops::Div, DT_HALF>(this);
   TestBinaryTensorOpWeightNoBroadcast<ops::RealDiv, DT_HALF>(this);
-#endif
 
   // Test BinaryTensorOpWeight() with channel-wise broadcasting.
   TestBinaryTensorOpWeightWithChannelWiseBroadcast<DT_FLOAT>(this);
@@ -2192,7 +2227,8 @@ void TestConvertSquare(OpConverterTest* test) {
   auto square = ops::Square(s.WithOpName("my_square"), input);
   NodeDef node_def = square.operation.node()->def();
 
-  test->AddTestTensor("input", {1, 20});
+  test->AddTestTensor("input", {1, 20}, /*batch_size=*/1,
+                      TfDataTypeToTrt(dtype));
   test->RunValidationAndConversion(node_def);
   TRT_TensorOrWeights output;
   TF_EXPECT_OK(test->GetTensorOrWeights("my_square", &output));
@@ -2202,14 +2238,18 @@ void TestConvertSquare(OpConverterTest* test) {
   const int num_inputs = 20;
   std::vector<CType> inputs(num_inputs);
   std::vector<CType> expected_outputs(num_inputs);
-  for (int i = 0; i < 20; i++) {
+  for (int i = 0; i < num_inputs; ++i) {
     const CType value = CType(i - 9);
     inputs[i] = value;
     expected_outputs[i] = value * value;
   }
   const DataVec input_data{{"input", test::AsTensor<CType>(inputs)}};
+  // Engine outputs are converted to FP16 automatically if we set FP16 mode in
+  // the builder.
   DataVec output_data{{"my_square", ConstructTensor<CType>(num_inputs)}};
-  test->BuildAndRun(input_data, &output_data);
+  test->BuildAndRun(
+      input_data, &output_data,
+      dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32);
   ExpectArrayNear(expected_outputs, GetSpanForData<CType>(output_data[0]));
 }
@@ -2237,9 +2277,7 @@ TEST_F(OpConverterTest, ConvertSquare) {
   // OK. Note that kINT32 is not supported by IElementWiseLayer, so we don't
   // test DT_INT32 type here.
   TestConvertSquare<DT_FLOAT>(this);
-  // TODO(tmorris): Looks like there may be a bug with this layer for FP16
-  // inputs. Disabling for now.
-  // TestConvertSquare<DT_HALF>(this);
+  TestConvertSquare<DT_HALF>(this);
 }
 
 TEST_F(OpConverterTest, ConvertActivation) {
@@ -2269,10 +2307,10 @@ TEST_F(OpConverterTest, ConvertActivation) {
     Scope s = Scope::NewRootScope();
     auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT);
     if (op_name == "LeakyRelu") {
-      // LeakyRelu does not have a C++ API
-      NodeDef node_def = MakeNodeDef("my_act", "LeakyRelu", {"input"});
-      (*node_def.mutable_attr())["alpha"].set_f(kAlpha);
-      return node_def;
+      auto act =
+          ops::internal::LeakyRelu(s.WithOpName("my_act"), input,
+                                   ops::internal::LeakyRelu::Alpha(kAlpha));
+      return act.operation.node()->def();
     } else if (op_name == "Relu") {
       auto act = ops::Relu(s.WithOpName("my_act"), input);
       return act.operation.node()->def();