// File: tgn.h
// (Source listing; see the generated documentation for this file.)
#pragma once
#include <torch/nn/module.h>
#include <torch/types.h>
#include <cstddef>
#include <memory>
#include <optional>
#include <stdexcept>
#include <tuple>
#include <utility>
#include <vector>
#include "tguf.h"
namespace tgn {
/// Plain aggregate of settings handed to TGNImpl's constructor.
///
/// Every field carries a default, so a default-constructed TGNConfig is
/// usable as-is. The per-field notes marked NOTE(review) are inferred from
/// the field names only; the code that consumes them is not in this header.
struct TGNConfig {
  /// Torch device; defaults to the CPU.
  torch::Device device{torch::kCPU};

  /// NOTE(review): presumably the width of the produced embeddings — confirm.
  std::size_t embedding_dim{100};

  /// NOTE(review): presumably the width of the per-node memory — confirm.
  std::size_t memory_dim{100};

  /// NOTE(review): presumably the width of the time encoding — confirm.
  std::size_t time_dim{100};

  /// NOTE(review): presumably the number of attention heads — confirm.
  std::size_t num_heads{2};

  /// NOTE(review): presumably the number of neighbours used per node — confirm.
  std::size_t num_nbrs{10};

  /// Dropout value; defaults to 0.1.
  float dropout{0.1f};
};
/// Implementation class of the TGN module (wrapped by `TORCH_MODULE(TGN)`).
///
/// All state is held behind the opaque `Impl` type, which is only
/// forward-declared here; every non-template member is defined out of line.
class TGNImpl : public torch::nn::Module {
 public:
  /// Builds the module from a configuration and a shared graph store.
  /// @param cfg   Settings for the model (see TGNConfig).
  /// @param store Shared handle to the tguf::TGStore instance.
  TGNImpl(const TGNConfig& cfg, const std::shared_ptr<tguf::TGStore>& store);

  /// Declared here and defined out of line: `Impl` is incomplete in this
  /// header, so `std::unique_ptr<Impl>` cannot be destroyed from it.
  ~TGNImpl() override;

  /// NOTE(review): presumably detaches the stored memory from the autograd
  /// graph — confirm against the implementation.
  auto detach_memory() -> void;

  /// NOTE(review): presumably restores the internal state to its initial
  /// value — confirm against the implementation.
  auto reset_state() -> void;

  /// Feeds a batch of events into the module's internal state.
  /// NOTE(review): tensor shapes/dtypes are not visible from this header;
  /// presumably `src`, `dst`, `time` and `msg` are aligned per event — confirm.
  /// @param src  Source tensor of the batch.
  /// @param dst  Destination tensor of the batch.
  /// @param time Time tensor of the batch.
  /// @param msg  Message tensor of the batch.
  auto update_state(const torch::Tensor& src, const torch::Tensor& dst,
                    const torch::Tensor& time, const torch::Tensor& msg)
      -> void;

  /// Returns a torch::Device for this module (defined out of line).
  auto device() const -> torch::Device;

  /// Variadic forward pass.
  ///
  /// Collects the arguments into a `std::vector<torch::Tensor>`, delegates to
  /// `forward_internal`, and hands the results back as a `std::tuple` with
  /// one element per argument, in order.
  ///
  /// @tparam Ts Argument types; each must be usable to initialise a
  ///            `torch::Tensor`.
  /// @param inputs One or more input ID tensors.
  /// @return `std::tuple` of `sizeof...(inputs)` tensors.
  /// @throws std::invalid_argument if called with no arguments.
  /// @throws std::out_of_range if `forward_internal` yields fewer tensors
  ///         than there were inputs.
  template <typename... Ts>
  auto forward(const Ts&... inputs) {
    // Kept as a runtime error (not a static_assert) so that existing callers
    // which compile a zero-argument call and catch the exception still work.
    if constexpr (sizeof...(inputs) == 0) {
      throw std::invalid_argument(
          "TGN::forward requires at least one input ID tensor.");
    }
    std::vector<torch::Tensor> input_list = {inputs...};
    auto results = forward_internal(input_list);
    return vec_to_tuple<sizeof...(inputs)>(
        results, std::make_index_sequence<sizeof...(inputs)>{});
  }

 private:
  /// Non-template worker behind `forward`; defined out of line.
  /// @param input_list The tensors passed to `forward`, in order.
  /// @return Result tensors; `forward` reads one per input, in order.
  auto forward_internal(const std::vector<torch::Tensor>& input_list)
      -> std::vector<torch::Tensor>;

  /// Expands the elements of `v` selected by `Is...` into a `std::tuple`.
  ///
  /// Uses bounds-checked access: a vector shorter than the requested indices
  /// raises `std::out_of_range` instead of reading past the end.
  ///
  /// @tparam N  Element count supplied by the call site (not otherwise used).
  /// @tparam Is Indices to extract.
  template <std::size_t N, std::size_t... Is>
  auto vec_to_tuple(const std::vector<torch::Tensor>& v,
                    std::index_sequence<Is...>) {
    return std::make_tuple(v.at(Is)...);
  }

  struct Impl;                  // Opaque state; defined elsewhere.
  std::unique_ptr<Impl> impl_;  // Sole owner of the hidden state.
};
// Declares the `TGN` handle type for TGNImpl via LibTorch's TORCH_MODULE macro
// (per the LibTorch docs, a torch::nn::ModuleHolder<TGNImpl> wrapper).
TORCH_MODULE(TGN);
} // namespace tgn