#include "tfmodelstate.h"

#include "workspace_status.h"

using namespace tensorflow;
using std::vector;

TFModelState::TFModelState()
  : ModelState()
  , mmap_env_(nullptr)
  , session_(nullptr)
{
}

TFModelState::~TFModelState()
{
  if (session_) {
    Status status = session_->Close();
    if (!status.ok()) {
      std::cerr << "Error closing TensorFlow session: " << status << std::endl;
    }
  }
}

int
TFModelState::init(const char* model_path)
{
  int err = ModelState::init(model_path);
  if (err != STT_ERR_OK) {
    return err;
  }

  Status status;
  SessionOptions options;

  mmap_env_.reset(new MemmappedEnv(Env::Default()));

  // Memory-mappable models use the ".pbmm" extension; match it as a suffix
  // so a ".pbmm" elsewhere in the path can't trigger a false positive.
  const std::string path(model_path);
  const std::string mmap_ext(".pbmm");
  bool is_mmap = path.size() >= mmap_ext.size() &&
                 path.compare(path.size() - mmap_ext.size(),
                              mmap_ext.size(), mmap_ext) == 0;
  if (!is_mmap) {
    std::cerr << "Warning: reading entire model file into memory. Transform model file into an mmapped graph to reduce heap usage." << std::endl;
  } else {
    status = mmap_env_->InitializeFromFile(model_path);
    if (!status.ok()) {
      std::cerr << status << std::endl;
      return STT_ERR_FAIL_INIT_MMAP;
    }

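    // Run the graph optimizer at level L0: higher levels can constant-fold
    // the mmapped weights, copying them back onto the heap and defeating
    // the purpose of memory-mapping.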
    options.config.mutable_graph_options()
      ->mutable_optimizer_options()
      ->set_opt_level(::OptimizerOptions::L0);
    options.env = mmap_env_.get();
  }

  Session* session;
  status = NewSession(options, &session);
  if (!status.ok()) {
    std::cerr << status << std::endl;
    return STT_ERR_FAIL_INIT_SESS;
  }
  session_.reset(session);

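  // For mmapped models the GraphDef lives inside the memmapped package;
  // otherwise read it straight from disk.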
  if (is_mmap) {
    status = ReadBinaryProto(mmap_env_.get(),
                             MemmappedFileSystem::kMemmappedPackageDefaultGraphDef,
                             &graph_def_);
  } else {
    status = ReadBinaryProto(Env::Default(), model_path, &graph_def_);
  }
  if (!status.ok()) {
    std::cerr << status << std::endl;
    return STT_ERR_FAIL_READ_PROTOBUF;
  }

  status = session_->Create(graph_def_);
  if (!status.ok()) {
    std::cerr << status << std::endl;
    return STT_ERR_FAIL_CREATE_SESS;
  }

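  // The exported graph embeds its version in a constant "metadata_version"
  // node; fetch it first so incompatible models fail fast.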
  vector<Tensor> version_output;
  status = session_->Run({}, {
    "metadata_version"
  }, {}, &version_output);
  if (!status.ok()) {
    std::cerr << "Unable to fetch graph version: " << status << std::endl;
    return STT_ERR_MODEL_INCOMPATIBLE;
  }

  int graph_version = version_output[0].scalar<int>()();
  if (graph_version < ds_graph_version()) {
    std::cerr << "Specified model file version (" << graph_version << ") is "
              << "incompatible with minimum version supported by this client ("
              << ds_graph_version() << "). See "
              << "https://stt.readthedocs.io/en/latest/USING.html#model-compatibility "
              << "for more information" << std::endl;
    return STT_ERR_MODEL_INCOMPATIBLE;
  }

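  // The remaining metadata nodes describe the feature pipeline (sample rate,
  // window geometry), the decoder beam width, and the serialized alphabet.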
  vector<Tensor> metadata_outputs;
  status = session_->Run({}, {
    "metadata_sample_rate",
    "metadata_feature_win_len",
    "metadata_feature_win_step",
    "metadata_beam_width",
    "metadata_alphabet",
  }, {}, &metadata_outputs);
  if (!status.ok()) {
    std::cerr << "Unable to fetch metadata: " << status << std::endl;
    return STT_ERR_MODEL_INCOMPATIBLE;
  }

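  // Window length and step are stored in milliseconds; convert them to
  // sample counts at the model's expected sample rate.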
  sample_rate_ = metadata_outputs[0].scalar<int>()();
  int win_len_ms = metadata_outputs[1].scalar<int>()();
  int win_step_ms = metadata_outputs[2].scalar<int>()();
  audio_win_len_ = sample_rate_ * (win_len_ms / 1000.0);
  audio_win_step_ = sample_rate_ * (win_step_ms / 1000.0);
  int beam_width = metadata_outputs[3].scalar<int>()();
  beam_width_ = (unsigned int)(beam_width);

  tensorflow::tstring serialized_alphabet = metadata_outputs[4].scalar<tensorflow::tstring>()();
  err = alphabet_.Deserialize(serialized_alphabet.data(), serialized_alphabet.size());
  if (err != 0) {
    return STT_ERR_INVALID_ALPHABET;
  }

  assert(sample_rate_ > 0);
  assert(audio_win_len_ > 0);
  assert(audio_win_step_ > 0);
  assert(beam_width_ > 0);
  assert(alphabet_.GetSize() > 0);

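  // Infer the model geometry from node shapes: input window dimensions,
  // recurrent state size, and the number of output classes.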
  for (int i = 0; i < graph_def_.node_size(); ++i) {
    // Take a reference; copying each NodeDef per iteration is wasteful.
    const NodeDef& node = graph_def_.node(i);
    if (node.name() == "input_node") {
      // input_node is [batch_size=1, n_steps, 2*n_context+1, n_features].
      const auto& shape = node.attr().at("shape").shape();
      n_steps_ = shape.dim(1).size();
      n_context_ = (shape.dim(2).size() - 1) / 2;
      n_features_ = shape.dim(3).size();
      mfcc_feats_per_timestep_ = shape.dim(2).size() * shape.dim(3).size();
    } else if (node.name() == "previous_state_c") {
      const auto& shape = node.attr().at("shape").shape();
      state_size_ = shape.dim(1).size();
    } else if (node.name() == "logits_shape") {
      Tensor logits_shape(DT_INT32, TensorShape({3}));
      if (!logits_shape.FromProto(node.attr().at("value").tensor())) {
        continue;
      }

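      // The last logits dimension includes the CTC blank label, hence the -1
      // when comparing against the alphabet size.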
      int final_dim_size = logits_shape.vec<int>()(2) - 1;
      if (final_dim_size != alphabet_.GetSize()) {
        std::cerr << "Error: Alphabet size does not match loaded model: alphabet "
                  << "has size " << alphabet_.GetSize()
                  << ", but model has " << final_dim_size
                  << " classes in its output. Make sure you're passing an alphabet "
                  << "file with the same size as the one used for training."
                  << std::endl;
        return STT_ERR_INVALID_ALPHABET;
      }
    }
  }

  if (n_context_ == -1 || n_features_ == -1) {
    std::cerr << "Error: Could not infer input shape from model file. "
              << "Make sure input_node is a 4D tensor with shape "
              << "[batch_size=1, time, window_size, n_features]."
              << std::endl;
    return STT_ERR_INVALID_SHAPE;
  }

  return STT_ERR_OK;
}

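// Copies `vec` into a newly allocated DT_FLOAT tensor of the given shape,
// zero-padding any elements past vec.size().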
Tensor
tensor_from_vector(const std::vector<float>& vec, const TensorShape& shape)
{
  Tensor ret(DT_FLOAT, shape);
  auto ret_mapped = ret.flat<float>();
  assert(static_cast<int64>(vec.size()) <= shape.num_elements());
  int64 i;
  for (i = 0; i < static_cast<int64>(vec.size()); ++i) {
    ret_mapped(i) = vec[i];
  }
  // Zero-fill the tail so partially filled windows are well-defined.
  for (; i < shape.num_elements(); ++i) {
    ret_mapped(i) = 0.f;
  }
  return ret;
}

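// Appends the first `num_elements` floats of `tensor` to `vec`; pass -1
// (the default) to copy the whole tensor.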
void
copy_tensor_to_vector(const Tensor& tensor, vector<float>& vec, int num_elements = -1)
{
  auto tensor_mapped = tensor.flat<float>();
  if (num_elements == -1) {
    num_elements = tensor.shape().num_elements();
  }
  for (int i = 0; i < num_elements; ++i) {
    vec.push_back(tensor_mapped(i));
  }
}

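// Runs one inference step: feeds a window of MFCC features plus the previous
// recurrent state (the "_c"/"_h" tensors follow LSTM cell/hidden naming) and
// collects per-timestep logits along with the updated state.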
void
TFModelState::infer(const std::vector<float>& mfcc,
                    unsigned int n_frames,
                    const std::vector<float>& previous_state_c,
                    const std::vector<float>& previous_state_h,
                    vector<float>& logits_output,
                    vector<float>& state_c_output,
                    vector<float>& state_h_output)
{
  const size_t num_classes = alphabet_.GetSize() + 1; // +1 for blank

  Tensor input = tensor_from_vector(mfcc, TensorShape({BATCH_SIZE, n_steps_, 2*n_context_+1, n_features_}));
  Tensor previous_state_c_t = tensor_from_vector(previous_state_c, TensorShape({BATCH_SIZE, (long long)state_size_}));
  Tensor previous_state_h_t = tensor_from_vector(previous_state_h, TensorShape({BATCH_SIZE, (long long)state_size_}));

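  // tensor_from_vector zero-pads the input up to n_steps_ timesteps, so tell
  // the model how many frames actually contain data.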
  Tensor input_lengths(DT_INT32, TensorShape({1}));
  input_lengths.scalar<int>()() = n_frames;

  vector<Tensor> outputs;
  Status status = session_->Run(
    {
     {"input_node", input},
     {"input_lengths", input_lengths},
     {"previous_state_c", previous_state_c_t},
     {"previous_state_h", previous_state_h_t}
    },
    {"logits", "new_state_c", "new_state_h"},
    {},
    &outputs);

  if (!status.ok()) {
    std::cerr << "Error running session: " << status << "\n";
    return;
  }

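  // Only the first n_frames timesteps of the logits carry real data; the
  // remainder is padding from the fixed-size input window.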
  copy_tensor_to_vector(outputs[0], logits_output, n_frames * BATCH_SIZE * num_classes);

  state_c_output.clear();
  state_c_output.reserve(state_size_);
  copy_tensor_to_vector(outputs[1], state_c_output);

  state_h_output.clear();
  state_h_output.reserve(state_size_);
  copy_tensor_to_vector(outputs[2], state_h_output);
}

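// Computes MFCC features for one window of audio samples by running the
// feature-extraction subgraph ("input_samples" -> "mfccs").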
void
TFModelState::compute_mfcc(const vector<float>& samples, vector<float>& mfcc_output)
{
  Tensor input = tensor_from_vector(samples, TensorShape({audio_win_len_}));

  vector<Tensor> outputs;
  Status status = session_->Run({{"input_samples", input}}, {"mfccs"}, {}, &outputs);

  if (!status.ok()) {
    std::cerr << "Error running session: " << status << "\n";
    return;
  }

  // The feature computation graph is hardcoded to one audio length for now
  const int n_windows = 1;
  assert(outputs[0].shape().num_elements() / n_features_ == n_windows);
  copy_tensor_to_vector(outputs[0], mfcc_output);
}