Skip to content

File tgn.h

File List > include > tgn.h

Go to the documentation of this file

#pragma once

#include <torch/nn/module.h>
#include <torch/types.h>

#include <cstddef>
#include <memory>
#include <optional>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>

#include "tguf.h"

namespace tgn {

/// Construction-time options passed to `TGNImpl`.
///
/// NOTE(review): the per-field meanings below are inferred from the field
/// names only; confirm them against the out-of-line `TGNImpl` implementation.
struct TGNConfig {
  torch::Device device = torch::kCPU;  ///< Presumably: device for the model.
  std::size_t embedding_dim = 100;     ///< Presumably: output embedding size.
  std::size_t memory_dim = 100;        ///< Presumably: per-node memory size.
  std::size_t time_dim = 100;          ///< Presumably: time-encoding size.
  std::size_t num_heads = 2;           ///< Presumably: attention head count.
  std::size_t num_nbrs = 10;  ///< Presumably: neighbours sampled per node.
  /// Presumably: dropout probability. Written as a `float` literal because
  /// `0.1` is a `double` and would be implicitly narrowed on initialisation.
  float dropout = 0.1F;
};

class TGNImpl : public torch::nn::Module {
 public:
  TGNImpl(const TGNConfig& cfg, const std::shared_ptr<tguf::TGStore>& store);
  ~TGNImpl();

  auto detach_memory() -> void;

  auto reset_state() -> void;

  auto update_state(const torch::Tensor& src, const torch::Tensor& dst,
                    const torch::Tensor& time, const torch::Tensor& msg)
      -> void;

  auto device() const -> torch::Device;

  template <typename... Ts>
  auto forward(const Ts&... inputs) {
    if constexpr (sizeof...(inputs) == 0) {
      throw std::invalid_argument(
          "TGN::forward requires at least one input ID tensor.");
    }
    std::vector<torch::Tensor> input_list = {inputs...};
    auto results = forward_internal(input_list);
    return vec_to_tuple<sizeof...(inputs)>(
        results, std::make_index_sequence<sizeof...(inputs)>{});
  }

 private:
  auto forward_internal(const std::vector<torch::Tensor>& input_list)
      -> std::vector<torch::Tensor>;

  template <std::size_t N, std::size_t... Is>
  auto vec_to_tuple(const std::vector<torch::Tensor>& v,
                    std::index_sequence<Is...>) {
    return std::make_tuple(v[Is]...);
  }

  struct Impl;
  std::unique_ptr<Impl> impl_;
};

// libtorch's TORCH_MODULE macro generates the `TGN` holder class wrapping
// `TGNImpl` (a `torch::nn::ModuleHolder`, which owns the impl through a
// shared_ptr); `TGN` is the name user code constructs and calls. The macro
// expands to a class definition without a trailing `;`, hence the one here.
TORCH_MODULE(TGN);

}  // namespace tgn