Added tf.cos to TFLite Ops.
PiperOrigin-RevId: 233643966
This commit is contained in:
parent
5f5088767d
commit
bb09f1e95e
@ -238,6 +238,7 @@ def generated_test_models():
|
||||
"conv2d_transpose",
|
||||
"conv_with_shared_weights",
|
||||
"conv_to_depthwiseconv_with_shared_weights",
|
||||
"cos",
|
||||
"depthwiseconv",
|
||||
"div",
|
||||
"equal",
|
||||
|
@ -133,6 +133,7 @@ typedef enum {
|
||||
kTfLiteBuiltinReverseV2 = 105,
|
||||
kTfLiteBuiltinAddN = 106,
|
||||
kTfLiteBuiltinGatherNd = 107,
|
||||
kTfLiteBuiltinCos = 108,
|
||||
} TfLiteBuiltinOperator;
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
@ -680,6 +680,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
|
||||
// ok for now, since there is no call implementation either.
|
||||
case BuiltinOperator_CALL:
|
||||
case BuiltinOperator_CONCAT_EMBEDDINGS:
|
||||
case BuiltinOperator_COS:
|
||||
case BuiltinOperator_CUSTOM:
|
||||
case BuiltinOperator_DEQUANTIZE:
|
||||
case BuiltinOperator_EMBEDDING_LOOKUP:
|
||||
|
@ -83,6 +83,10 @@ TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
return EvalNumeric(context, node, std::sin);
|
||||
}
|
||||
|
||||
TfLiteStatus CosEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
return EvalNumeric(context, node, std::cos);
|
||||
}
|
||||
|
||||
TfLiteStatus LogEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
return EvalNumeric(context, node, std::log);
|
||||
}
|
||||
@ -122,6 +126,14 @@ TfLiteRegistration* Register_SIN() {
|
||||
return &r;
|
||||
}
|
||||
|
||||
TfLiteRegistration* Register_COS() {
|
||||
static TfLiteRegistration r = {
|
||||
/*init=*/nullptr, /*free=*/nullptr,
|
||||
elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
|
||||
elementwise::CosEval};
|
||||
return &r;
|
||||
}
|
||||
|
||||
TfLiteRegistration* Register_LOG() {
|
||||
static TfLiteRegistration r = {
|
||||
/*init=*/nullptr, /*free=*/nullptr,
|
||||
|
@ -65,6 +65,15 @@ TEST(ElementWise, Sin) {
|
||||
EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1}));
|
||||
}
|
||||
|
||||
TEST(ElementWise, Cos) {
|
||||
ElementWiseOpFloatModel m(BuiltinOperator_COS, {1, 1, 4, 1});
|
||||
m.PopulateTensor<float>(m.input(), {0, 3.1415926, -3.1415926, 1});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.ExtractVector<float>(m.output()),
|
||||
ElementsAreArray(ArrayFloatNear({1, -1, -1, 0.54030})));
|
||||
EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1}));
|
||||
}
|
||||
|
||||
TEST(ElementWise, Log) {
|
||||
ElementWiseOpFloatModel m(BuiltinOperator_LOG, {1, 1, 4, 1});
|
||||
m.PopulateTensor<float>(m.input(), {1, 3.1415926, 1, 1});
|
||||
|
@ -104,6 +104,7 @@ TfLiteRegistration* Register_REDUCE_ANY();
|
||||
TfLiteRegistration* Register_SELECT();
|
||||
TfLiteRegistration* Register_SLICE();
|
||||
TfLiteRegistration* Register_SIN();
|
||||
TfLiteRegistration* Register_COS();
|
||||
TfLiteRegistration* Register_TRANSPOSE_CONV();
|
||||
TfLiteRegistration* Register_EXPAND_DIMS();
|
||||
TfLiteRegistration* Register_SPARSE_TO_DENSE();
|
||||
@ -304,6 +305,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
|
||||
AddBuiltin(BuiltinOperator_SLICE, Register_SLICE(), /* min_version */ 1,
|
||||
/* max_version */ 2);
|
||||
AddBuiltin(BuiltinOperator_SIN, Register_SIN());
|
||||
AddBuiltin(BuiltinOperator_COS, Register_COS());
|
||||
AddBuiltin(BuiltinOperator_TRANSPOSE_CONV, Register_TRANSPOSE_CONV());
|
||||
AddBuiltin(BuiltinOperator_TILE, Register_TILE());
|
||||
AddBuiltin(BuiltinOperator_SUM, Register_SUM());
|
||||
|
@ -610,6 +610,7 @@ TfLiteStatus AddOpsAndParams(
|
||||
case tflite::BuiltinOperator_SPLIT:
|
||||
case tflite::BuiltinOperator_STRIDED_SLICE:
|
||||
case tflite::BuiltinOperator_EXP:
|
||||
case tflite::BuiltinOperator_COS:
|
||||
case tflite::BuiltinOperator_LOG_SOFTMAX:
|
||||
case tflite::BuiltinOperator_DEQUANTIZE:
|
||||
case tflite::BuiltinOperator_DELEGATE:
|
||||
|
@ -221,6 +221,7 @@ enum BuiltinOperator : byte {
|
||||
REVERSE_V2 = 105,
|
||||
ADD_N = 106,
|
||||
GATHER_ND = 107,
|
||||
COS = 108,
|
||||
}
|
||||
|
||||
// Options for the builtin operators.
|
||||
@ -308,6 +309,7 @@ union BuiltinOptions {
|
||||
ReverseV2Options,
|
||||
AddNOptions,
|
||||
GatherNdOptions,
|
||||
CosOptions,
|
||||
}
|
||||
|
||||
enum Padding : byte { SAME, VALID }
|
||||
@ -551,6 +553,9 @@ table TransposeOptions {
|
||||
table ExpOptions {
|
||||
}
|
||||
|
||||
table CosOptions {
|
||||
}
|
||||
|
||||
table ReducerOptions {
|
||||
keep_dims: bool;
|
||||
}
|
||||
|
@ -139,6 +139,9 @@ struct TransposeOptionsT;
|
||||
struct ExpOptions;
|
||||
struct ExpOptionsT;
|
||||
|
||||
struct CosOptions;
|
||||
struct CosOptionsT;
|
||||
|
||||
struct ReducerOptions;
|
||||
struct ReducerOptionsT;
|
||||
|
||||
@ -537,11 +540,12 @@ enum BuiltinOperator {
|
||||
BuiltinOperator_REVERSE_V2 = 105,
|
||||
BuiltinOperator_ADD_N = 106,
|
||||
BuiltinOperator_GATHER_ND = 107,
|
||||
BuiltinOperator_COS = 108,
|
||||
BuiltinOperator_MIN = BuiltinOperator_ADD,
|
||||
BuiltinOperator_MAX = BuiltinOperator_GATHER_ND
|
||||
BuiltinOperator_MAX = BuiltinOperator_COS
|
||||
};
|
||||
|
||||
inline const BuiltinOperator (&EnumValuesBuiltinOperator())[107] {
|
||||
inline const BuiltinOperator (&EnumValuesBuiltinOperator())[108] {
|
||||
static const BuiltinOperator values[] = {
|
||||
BuiltinOperator_ADD,
|
||||
BuiltinOperator_AVERAGE_POOL_2D,
|
||||
@ -649,7 +653,8 @@ inline const BuiltinOperator (&EnumValuesBuiltinOperator())[107] {
|
||||
BuiltinOperator_CEIL,
|
||||
BuiltinOperator_REVERSE_V2,
|
||||
BuiltinOperator_ADD_N,
|
||||
BuiltinOperator_GATHER_ND
|
||||
BuiltinOperator_GATHER_ND,
|
||||
BuiltinOperator_COS
|
||||
};
|
||||
return values;
|
||||
}
|
||||
@ -764,6 +769,7 @@ inline const char * const *EnumNamesBuiltinOperator() {
|
||||
"REVERSE_V2",
|
||||
"ADD_N",
|
||||
"GATHER_ND",
|
||||
"COS",
|
||||
nullptr
|
||||
};
|
||||
return names;
|
||||
@ -859,11 +865,12 @@ enum BuiltinOptions {
|
||||
BuiltinOptions_ReverseV2Options = 81,
|
||||
BuiltinOptions_AddNOptions = 82,
|
||||
BuiltinOptions_GatherNdOptions = 83,
|
||||
BuiltinOptions_CosOptions = 84,
|
||||
BuiltinOptions_MIN = BuiltinOptions_NONE,
|
||||
BuiltinOptions_MAX = BuiltinOptions_GatherNdOptions
|
||||
BuiltinOptions_MAX = BuiltinOptions_CosOptions
|
||||
};
|
||||
|
||||
inline const BuiltinOptions (&EnumValuesBuiltinOptions())[84] {
|
||||
inline const BuiltinOptions (&EnumValuesBuiltinOptions())[85] {
|
||||
static const BuiltinOptions values[] = {
|
||||
BuiltinOptions_NONE,
|
||||
BuiltinOptions_Conv2DOptions,
|
||||
@ -948,7 +955,8 @@ inline const BuiltinOptions (&EnumValuesBuiltinOptions())[84] {
|
||||
BuiltinOptions_UniqueOptions,
|
||||
BuiltinOptions_ReverseV2Options,
|
||||
BuiltinOptions_AddNOptions,
|
||||
BuiltinOptions_GatherNdOptions
|
||||
BuiltinOptions_GatherNdOptions,
|
||||
BuiltinOptions_CosOptions
|
||||
};
|
||||
return values;
|
||||
}
|
||||
@ -1039,6 +1047,7 @@ inline const char * const *EnumNamesBuiltinOptions() {
|
||||
"ReverseV2Options",
|
||||
"AddNOptions",
|
||||
"GatherNdOptions",
|
||||
"CosOptions",
|
||||
nullptr
|
||||
};
|
||||
return names;
|
||||
@ -1385,6 +1394,10 @@ template<> struct BuiltinOptionsTraits<GatherNdOptions> {
|
||||
static const BuiltinOptions enum_value = BuiltinOptions_GatherNdOptions;
|
||||
};
|
||||
|
||||
template<> struct BuiltinOptionsTraits<CosOptions> {
|
||||
static const BuiltinOptions enum_value = BuiltinOptions_CosOptions;
|
||||
};
|
||||
|
||||
struct BuiltinOptionsUnion {
|
||||
BuiltinOptions type;
|
||||
void *value;
|
||||
@ -2080,6 +2093,14 @@ struct BuiltinOptionsUnion {
|
||||
return type == BuiltinOptions_GatherNdOptions ?
|
||||
reinterpret_cast<const GatherNdOptionsT *>(value) : nullptr;
|
||||
}
|
||||
CosOptionsT *AsCosOptions() {
|
||||
return type == BuiltinOptions_CosOptions ?
|
||||
reinterpret_cast<CosOptionsT *>(value) : nullptr;
|
||||
}
|
||||
const CosOptionsT *AsCosOptions() const {
|
||||
return type == BuiltinOptions_CosOptions ?
|
||||
reinterpret_cast<const CosOptionsT *>(value) : nullptr;
|
||||
}
|
||||
};
|
||||
|
||||
bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type);
|
||||
@ -5012,6 +5033,46 @@ inline flatbuffers::Offset<ExpOptions> CreateExpOptions(
|
||||
|
||||
flatbuffers::Offset<ExpOptions> CreateExpOptions(flatbuffers::FlatBufferBuilder &_fbb, const ExpOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
|
||||
|
||||
struct CosOptionsT : public flatbuffers::NativeTable {
|
||||
typedef CosOptions TableType;
|
||||
CosOptionsT() {
|
||||
}
|
||||
};
|
||||
|
||||
struct CosOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
||||
typedef CosOptionsT NativeTableType;
|
||||
bool Verify(flatbuffers::Verifier &verifier) const {
|
||||
return VerifyTableStart(verifier) &&
|
||||
verifier.EndTable();
|
||||
}
|
||||
CosOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
|
||||
void UnPackTo(CosOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
|
||||
static flatbuffers::Offset<CosOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const CosOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
|
||||
};
|
||||
|
||||
struct CosOptionsBuilder {
|
||||
flatbuffers::FlatBufferBuilder &fbb_;
|
||||
flatbuffers::uoffset_t start_;
|
||||
explicit CosOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
|
||||
: fbb_(_fbb) {
|
||||
start_ = fbb_.StartTable();
|
||||
}
|
||||
CosOptionsBuilder &operator=(const CosOptionsBuilder &);
|
||||
flatbuffers::Offset<CosOptions> Finish() {
|
||||
const auto end = fbb_.EndTable(start_);
|
||||
auto o = flatbuffers::Offset<CosOptions>(end);
|
||||
return o;
|
||||
}
|
||||
};
|
||||
|
||||
inline flatbuffers::Offset<CosOptions> CreateCosOptions(
|
||||
flatbuffers::FlatBufferBuilder &_fbb) {
|
||||
CosOptionsBuilder builder_(_fbb);
|
||||
return builder_.Finish();
|
||||
}
|
||||
|
||||
flatbuffers::Offset<CosOptions> CreateCosOptions(flatbuffers::FlatBufferBuilder &_fbb, const CosOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
|
||||
|
||||
struct ReducerOptionsT : public flatbuffers::NativeTable {
|
||||
typedef ReducerOptions TableType;
|
||||
bool keep_dims;
|
||||
@ -7678,6 +7739,9 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
||||
const GatherNdOptions *builtin_options_as_GatherNdOptions() const {
|
||||
return builtin_options_type() == BuiltinOptions_GatherNdOptions ? static_cast<const GatherNdOptions *>(builtin_options()) : nullptr;
|
||||
}
|
||||
const CosOptions *builtin_options_as_CosOptions() const {
|
||||
return builtin_options_type() == BuiltinOptions_CosOptions ? static_cast<const CosOptions *>(builtin_options()) : nullptr;
|
||||
}
|
||||
const flatbuffers::Vector<uint8_t> *custom_options() const {
|
||||
return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
|
||||
}
|
||||
@ -8041,6 +8105,10 @@ template<> inline const GatherNdOptions *Operator::builtin_options_as<GatherNdOp
|
||||
return builtin_options_as_GatherNdOptions();
|
||||
}
|
||||
|
||||
template<> inline const CosOptions *Operator::builtin_options_as<CosOptions>() const {
|
||||
return builtin_options_as_CosOptions();
|
||||
}
|
||||
|
||||
struct OperatorBuilder {
|
||||
flatbuffers::FlatBufferBuilder &fbb_;
|
||||
flatbuffers::uoffset_t start_;
|
||||
@ -9586,6 +9654,29 @@ inline flatbuffers::Offset<ExpOptions> CreateExpOptions(flatbuffers::FlatBufferB
|
||||
_fbb);
|
||||
}
|
||||
|
||||
inline CosOptionsT *CosOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
|
||||
auto _o = new CosOptionsT();
|
||||
UnPackTo(_o, _resolver);
|
||||
return _o;
|
||||
}
|
||||
|
||||
inline void CosOptions::UnPackTo(CosOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
|
||||
(void)_o;
|
||||
(void)_resolver;
|
||||
}
|
||||
|
||||
inline flatbuffers::Offset<CosOptions> CosOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const CosOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
|
||||
return CreateCosOptions(_fbb, _o, _rehasher);
|
||||
}
|
||||
|
||||
inline flatbuffers::Offset<CosOptions> CreateCosOptions(flatbuffers::FlatBufferBuilder &_fbb, const CosOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
|
||||
(void)_rehasher;
|
||||
(void)_o;
|
||||
struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const CosOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
|
||||
return tflite::CreateCosOptions(
|
||||
_fbb);
|
||||
}
|
||||
|
||||
inline ReducerOptionsT *ReducerOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
|
||||
auto _o = new ReducerOptionsT();
|
||||
UnPackTo(_o, _resolver);
|
||||
@ -11347,6 +11438,10 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob
|
||||
auto ptr = reinterpret_cast<const GatherNdOptions *>(obj);
|
||||
return verifier.VerifyTable(ptr);
|
||||
}
|
||||
case BuiltinOptions_CosOptions: {
|
||||
auto ptr = reinterpret_cast<const CosOptions *>(obj);
|
||||
return verifier.VerifyTable(ptr);
|
||||
}
|
||||
default: return false;
|
||||
}
|
||||
}
|
||||
@ -11697,6 +11792,10 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c
|
||||
auto ptr = reinterpret_cast<const GatherNdOptions *>(obj);
|
||||
return ptr->UnPack(resolver);
|
||||
}
|
||||
case BuiltinOptions_CosOptions: {
|
||||
auto ptr = reinterpret_cast<const CosOptions *>(obj);
|
||||
return ptr->UnPack(resolver);
|
||||
}
|
||||
default: return nullptr;
|
||||
}
|
||||
}
|
||||
@ -12035,6 +12134,10 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff
|
||||
auto ptr = reinterpret_cast<const GatherNdOptionsT *>(value);
|
||||
return CreateGatherNdOptions(_fbb, ptr, _rehasher).Union();
|
||||
}
|
||||
case BuiltinOptions_CosOptions: {
|
||||
auto ptr = reinterpret_cast<const CosOptionsT *>(value);
|
||||
return CreateCosOptions(_fbb, ptr, _rehasher).Union();
|
||||
}
|
||||
default: return 0;
|
||||
}
|
||||
}
|
||||
@ -12373,6 +12476,10 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL
|
||||
value = new GatherNdOptionsT(*reinterpret_cast<GatherNdOptionsT *>(u.value));
|
||||
break;
|
||||
}
|
||||
case BuiltinOptions_CosOptions: {
|
||||
value = new CosOptionsT(*reinterpret_cast<CosOptionsT *>(u.value));
|
||||
break;
|
||||
}
|
||||
default:
|
||||
break;
|
||||
}
|
||||
@ -12795,6 +12902,11 @@ inline void BuiltinOptionsUnion::Reset() {
|
||||
delete ptr;
|
||||
break;
|
||||
}
|
||||
case BuiltinOptions_CosOptions: {
|
||||
auto ptr = reinterpret_cast<CosOptionsT *>(value);
|
||||
delete ptr;
|
||||
break;
|
||||
}
|
||||
default: break;
|
||||
}
|
||||
value = nullptr;
|
||||
|
@ -1111,6 +1111,34 @@ def make_exp_tests(zip_path):
|
||||
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
|
||||
|
||||
|
||||
def make_cos_tests(zip_path):
|
||||
"""Make a set of tests to do cos."""
|
||||
|
||||
test_parameters = [{
|
||||
"input_dtype": [tf.float32],
|
||||
"input_shape": [[], [3], [1, 100], [4, 2, 3], [5, 224, 224, 3]],
|
||||
}]
|
||||
|
||||
def build_graph(parameters):
|
||||
"""Build the cos op testing graph."""
|
||||
input_tensor = tf.placeholder(
|
||||
dtype=parameters["input_dtype"],
|
||||
name="input",
|
||||
shape=parameters["input_shape"])
|
||||
|
||||
out = tf.cos(input_tensor)
|
||||
return [input_tensor], [out]
|
||||
|
||||
def build_inputs(parameters, sess, inputs, outputs):
|
||||
values = [
|
||||
create_tensor_data(parameters["input_dtype"], parameters["input_shape"],
|
||||
min_value=-np.pi, max_value=np.pi)
|
||||
]
|
||||
return values, sess.run(outputs, feed_dict=dict(zip(inputs, values)))
|
||||
|
||||
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
|
||||
|
||||
|
||||
def make_log_softmax_tests(zip_path):
|
||||
"""Make a set of tests to do log_softmax."""
|
||||
|
||||
|
@ -2057,6 +2057,7 @@ void ProcessUniqueOperator(Model* model, UniqueOperator* op) {
|
||||
case OperatorType::kCeil:
|
||||
case OperatorType::kExp:
|
||||
case OperatorType::kSin:
|
||||
case OperatorType::kCos:
|
||||
case OperatorType::kLogicalAnd:
|
||||
case OperatorType::kLogicalNot:
|
||||
case OperatorType::kLogicalOr:
|
||||
|
@ -2396,6 +2396,7 @@ ConverterMapType GetTensorFlowNodeConverterMap() {
|
||||
{"Const", ConvertConstOperator},
|
||||
{"Conv2D", ConvertConvOperator},
|
||||
{"Conv2DBackpropInput", ConvertTransposeConvOperator},
|
||||
{"Cos", ConvertSimpleOperator<CosOperator, 1, 1>},
|
||||
{"CTCBeamSearchDecoder", ConvertCTCBeamSearchDecoderOperator},
|
||||
{"DepthToSpace", ConvertDepthToSpaceOperator},
|
||||
{"DepthwiseConv2dNative", ConvertDepthwiseConvOperator},
|
||||
|
@ -45,6 +45,7 @@ enum class OperatorType : uint8 {
|
||||
kCeil,
|
||||
kConv,
|
||||
kConcatenation,
|
||||
kCos,
|
||||
kDepthwiseConv,
|
||||
kDepthToSpace,
|
||||
kSpaceToDepth,
|
||||
@ -1166,6 +1167,17 @@ struct ExpOperator : Operator {
|
||||
ExpOperator() : Operator(OperatorType::kExp) {}
|
||||
};
|
||||
|
||||
// Given a tensor input, this operation calculates element-wise exponential
|
||||
// (y = cos(x)).
|
||||
//
|
||||
// Inputs:
|
||||
// inputs[0]: required: input tensor
|
||||
//
|
||||
// TensorFlow equivalent: Cos
|
||||
struct CosOperator : Operator {
|
||||
CosOperator() : Operator(OperatorType::kCos) {}
|
||||
};
|
||||
|
||||
// Given a tensor input, this operation inserts a dimension of 1 at the
|
||||
// dimension index axis of input's shape. The dimension index axis starts at
|
||||
// zero; if you specify a negative number for axis it is counted backward from
|
||||
|
@ -2343,6 +2343,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList(
|
||||
MakeUnique<SimpleOperator<TanhOperator>>("TANH", OperatorType::kTanh));
|
||||
ops.push_back(
|
||||
MakeUnique<SimpleOperator<ExpOperator>>("EXP", OperatorType::kExp));
|
||||
ops.push_back(
|
||||
MakeUnique<SimpleOperator<CosOperator>>("COS", OperatorType::kCos));
|
||||
ops.push_back(MakeUnique<SimpleOperator<LogSoftmaxOperator>>(
|
||||
"LOG_SOFTMAX", OperatorType::kLogSoftmax));
|
||||
ops.push_back(MakeUnique<Maximum>()); // Element-wise Maximum
|
||||
|
@ -119,6 +119,7 @@ TEST_F(OperatorTest, SimpleOperators) {
|
||||
CheckSimpleOperator<LogisticOperator>("LOGISTIC", OperatorType::kLogistic);
|
||||
CheckSimpleOperator<TanhOperator>("TANH", OperatorType::kTanh);
|
||||
CheckSimpleOperator<ExpOperator>("EXP", OperatorType::kExp);
|
||||
CheckSimpleOperator<CosOperator>("COS", OperatorType::kCos);
|
||||
CheckSimpleOperator<LogSoftmaxOperator>("LOG_SOFTMAX",
|
||||
OperatorType::kLogSoftmax);
|
||||
CheckSimpleOperator<TensorFlowMaximumOperator>(
|
||||
|
@ -422,6 +422,7 @@ const char* OperatorTypeName(OperatorType type) {
|
||||
HANDLE_OPERATORTYPENAME_CASE(Unique)
|
||||
HANDLE_OPERATORTYPENAME_CASE(UnidirectionalSequenceRnn)
|
||||
HANDLE_OPERATORTYPENAME_CASE(ReverseV2)
|
||||
HANDLE_OPERATORTYPENAME_CASE(Cos)
|
||||
default:
|
||||
LOG(FATAL) << "Unhandled op type";
|
||||
#undef HANDLE_OPERATORTYPENAME_CASE
|
||||
|
Loading…
x
Reference in New Issue
Block a user