
Class tgn::TGNImpl


The core Temporal Graph Network module. Manages node memory state and temporal neighborhood aggregation.

  • #include <tgn.h>

Inherits the following classes: torch::nn::Module

Public Functions

Type Name
TGNImpl (const TGNConfig & cfg, const std::shared_ptr< tguf::TGStore > & store)
auto detach_memory ()
Detaches memory from the computational graph to truncate backpropagation through time (BPTT).
auto device () const
Get the torch::Device used by the module.
auto forward (const Ts &... inputs)
Variadic forward pass.
auto reset_state ()
Zeros out all node memory and resets last-update timestamps.
auto update_state (const torch::Tensor & src, const torch::Tensor & dst, const torch::Tensor & time, const torch::Tensor & msg)
Updates internal memory from a batch of observed (ground-truth) edge events.
~TGNImpl ()

Public Functions Documentation

function TGNImpl

tgn::TGNImpl::TGNImpl (
    const TGNConfig & cfg,
    const std::shared_ptr< tguf::TGStore > & store
) 

function detach_memory

Detaches memory from the computational graph to truncate backpropagation through time (BPTT).

auto tgn::TGNImpl::detach_memory () 

function device

Get the torch::Device used by the module.

auto tgn::TGNImpl::device () const

function forward

Variadic forward pass.

template<typename... Ts>
inline auto tgn::TGNImpl::forward (
    const Ts &... inputs
) 

Parameters:

  • inputs: Tensors of node IDs for which to compute embeddings.

Returns:

A tuple with one embedding tensor per input, each of shape [B, embedding_dim] (B being the number of node IDs in that input), in the same order as the inputs.
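The variadic signature maps each argument to one element of the returned tuple, in argument order. The self-contained sketch below shows that pack-expansion pattern with a hypothetical ToyEmbedder; std::vector stands in for torch::Tensor, and the "embedding" is a placeholder (the node ID repeated embedding_dim times), so none of TGNImpl's memory or neighborhood aggregation is reproduced.

```cpp
#include <cstddef>
#include <tuple>
#include <vector>

// Hypothetical stand-ins: a "tensor" of node IDs is a std::vector<long>,
// and an embedding table is B rows of embedding_dim floats.
using NodeIds = std::vector<long>;
using Embeddings = std::vector<std::vector<float>>;

struct ToyEmbedder {
  std::size_t embedding_dim = 2;

  // Placeholder embedding: B node IDs -> a [B, embedding_dim] table whose
  // rows are filled with the node ID.
  Embeddings embed(const NodeIds& ids) const {
    Embeddings out;
    out.reserve(ids.size());
    for (long id : ids)
      out.push_back(std::vector<float>(embedding_dim, static_cast<float>(id)));
    return out;
  }

  // Variadic forward: the pack expansion embed(inputs)... yields one table
  // per argument, and std::make_tuple keeps them in argument order.
  template <typename... Ts>
  auto forward(const Ts&... inputs) const {
    return std::make_tuple(embed(inputs)...);
  }
};
```

Calling forward(src, dst) therefore returns a two-element tuple that can be unpacked with std::get<0> and std::get<1> (or structured bindings in C++17). A real implementation may instead embed all inputs in one joint pass and split the result; this sketch does not attempt that.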


function reset_state

Zeros out all node memory and resets last-update timestamps.

auto tgn::TGNImpl::reset_state () 

function update_state

Updates internal memory from a batch of observed (ground-truth) edge events.

auto tgn::TGNImpl::update_state (
    const torch::Tensor & src,
    const torch::Tensor & dst,
    const torch::Tensor & time,
    const torch::Tensor & msg
) 
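To make the parameter roles concrete, here is a toy per-node memory under loudly simplifying assumptions: src[i] -> dst[i] is an edge observed at time[i] with feature vector msg[i]; both endpoints' memories move halfway toward msg[i], and their last-update timestamps are set to time[i]; a reset_state counterpart zeros both. The ToyMemory type and its averaging rule are hypothetical. TGN proper computes messages from both endpoints' memories, the time delta, and the edge features, aggregates them per node, and updates memory with a learned recurrent cell such as a GRU.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <initializer_list>
#include <vector>

// Toy per-node state: one memory vector and one last-update time per node.
struct ToyMemory {
  std::vector<std::vector<float>> memory;  // [num_nodes, memory_dim]
  std::vector<double> last_update;         // [num_nodes]

  ToyMemory(std::size_t num_nodes, std::size_t memory_dim)
      : memory(num_nodes, std::vector<float>(memory_dim, 0.0f)),
        last_update(num_nodes, 0.0) {}

  // Same argument roles as update_state(src, dst, time, msg). Events are
  // applied one at a time in batch order; a batched implementation would
  // typically aggregate all messages per node first.
  void update_state(const std::vector<long>& src, const std::vector<long>& dst,
                    const std::vector<double>& time,
                    const std::vector<std::vector<float>>& msg) {
    assert(src.size() == dst.size() && src.size() == time.size() &&
           src.size() == msg.size());
    for (std::size_t i = 0; i < src.size(); ++i) {
      for (long node : {src[i], dst[i]}) {
        auto& m = memory[static_cast<std::size_t>(node)];
        // Hypothetical update rule: move each memory entry halfway to msg.
        for (std::size_t d = 0; d < m.size() && d < msg[i].size(); ++d)
          m[d] = 0.5f * m[d] + 0.5f * msg[i][d];
        last_update[static_cast<std::size_t>(node)] = time[i];
      }
    }
  }

  // Counterpart of reset_state(): zero all memory and last-update times.
  void reset_state() {
    for (auto& m : memory) std::fill(m.begin(), m.end(), 0.0f);
    std::fill(last_update.begin(), last_update.end(), 0.0);
  }
};
```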

function ~TGNImpl

tgn::TGNImpl::~TGNImpl () 


The documentation for this class was generated from the following file: include/tgn.h