Class tgn::TGNImpl
The core Temporal Graph Network module. Manages node memory state and temporal neighborhood aggregation.
#include <tgn.h>
Inherits the following classes: torch::nn::Module
Public Functions
| Type | Name | Description |
|---|---|---|
| | TGNImpl (const TGNConfig & cfg, const std::shared_ptr< tguf::TGStore > & store) | |
| auto | detach_memory () | Detaches memory from the computational graph to truncate backpropagation through time (BPTT). |
| auto | device () const | Get the torch::Device used by the module. |
| auto | forward (const Ts &... inputs) | Variadic forward pass. |
| auto | reset_state () | Zeros out all node memory and resets last-update timestamps. |
| auto | update_state (const torch::Tensor & src, const torch::Tensor & dst, const torch::Tensor & time, const torch::Tensor & msg) | Updates internal node memory from a batch of true (observed) edge events. |
| | ~TGNImpl () | |
Public Functions Documentation
function TGNImpl
tgn::TGNImpl::TGNImpl (
const TGNConfig & cfg,
const std::shared_ptr< tguf::TGStore > & store
)
function detach_memory
Detaches memory from the computational graph to truncate backpropagation through time (BPTT).
auto tgn::TGNImpl::detach_memory ()
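Why detaching matters can be seen in a toy, standard-library-only model of an autograd graph. This is not libtorch and not the tgn.h implementation; `Node`, `step`, `depth`, and `detach` are hypothetical names. Each memory update produces a node linked to the state it was computed from, so the chain a backward pass would have to walk grows by one per update. `detach` keeps the value but drops the link, which is the effect `torch::Tensor::detach()` has on a real memory tensor's gradient history.

```cpp
#include <cstddef>
#include <memory>

// Hypothetical toy graph node: a value plus a link to the node it was computed from.
struct Node {
    double value;
    std::shared_ptr<Node> parent;  // nullptr for a leaf (no history)
};

// Number of nodes a backward pass would visit walking from n back to a leaf.
std::size_t depth(const std::shared_ptr<Node>& n) {
    std::size_t d = 0;
    for (auto cur = n; cur; cur = cur->parent) ++d;
    return d;
}

// One toy "memory update": the new state depends on, and links to, the old one.
std::shared_ptr<Node> step(const std::shared_ptr<Node>& mem, double msg) {
    return std::make_shared<Node>(Node{0.5 * mem->value + msg, mem});
}

// Analogue of detaching: same value, but cut off from all earlier history.
std::shared_ptr<Node> detach(const std::shared_ptr<Node>& mem) {
    return std::make_shared<Node>(Node{mem->value, nullptr});
}
```

Detaching after every batch bounds the chain at the number of updates performed within that batch, so gradients flow only through the current batch's updates instead of the whole event history.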
function device
Get the torch::Device used by the module.
auto tgn::TGNImpl::device () const
function forward
Variadic forward pass.
template<typename... Ts>
inline auto tgn::TGNImpl::forward (
const Ts &... inputs
)
Parameters:
inputs: Tensors of node IDs to compute embeddings for.
Returns:
A tuple of embeddings, each of shape [B, embedding_dim], in the same order as the inputs.
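One plausible way a variadic forward can guarantee "same order as inputs" is parameter-pack expansion: expanding a per-input call inside `std::make_tuple` yields one tuple element per argument, in argument order. The sketch below shows that pattern with standard containers only; `Embedding`, `embed`, `forward_sketch`, and the constant-fill "embedding" are hypothetical stand-ins for the module's tensors and embedding logic, not its API.

```cpp
#include <cstddef>
#include <cstdint>
#include <tuple>
#include <vector>

// Hypothetical stand-in for a [B, embedding_dim] tensor: one row per node ID.
using Embedding = std::vector<std::vector<float>>;

// Hypothetical per-input lookup; each row is simply filled with the node ID.
Embedding embed(const std::vector<std::int64_t>& node_ids, std::size_t dim) {
    Embedding out;
    out.reserve(node_ids.size());
    for (auto id : node_ids) out.emplace_back(dim, static_cast<float>(id));
    return out;
}

// Variadic forward: one embedding per input, returned as a tuple.
// The evaluation order of the expanded calls is unspecified, but the tuple
// element at position i always corresponds to the i-th input.
template <typename... Ts>
auto forward_sketch(std::size_t dim, const Ts&... inputs) {
    return std::make_tuple(embed(inputs, dim)...);
}
```

Assuming the real `forward` likewise returns a `std::tuple`, callers can unpack it with structured bindings, e.g. `auto [z_src, z_dst, z_neg] = forward_sketch(4, src, dst, neg);`.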
function reset_state
Zeros out all node memory and resets last-update timestamps.
auto tgn::TGNImpl::reset_state ()
function update_state
Updates internal node memory from a batch of true (observed) edge events.
auto tgn::TGNImpl::update_state (
const torch::Tensor & src,
const torch::Tensor & dst,
const torch::Tensor & time,
const torch::Tensor & msg
)
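To make `reset_state` and `update_state` concrete, here is a toy, standard-library-only node-memory store. It is a simplified illustration under stated assumptions, not the tgn.h implementation: where TGN builds per-node messages from both endpoints' memories, the time delta, and the edge features, and updates memory with a recurrent cell such as a GRU, this sketch keeps only each node's most recent raw message in the batch (a "last message" aggregator) and applies a hypothetical `tanh(old + msg)` update. `ToyMemory` and its layout are assumptions.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <unordered_map>
#include <vector>

// Hypothetical toy memory store: per-node memory vector plus last-update time.
// Assumes node IDs are in [0, num_nodes) and every msg row has `dim` entries.
struct ToyMemory {
    std::size_t dim;
    std::vector<std::vector<float>> memory;  // [num_nodes][dim]
    std::vector<double> last_update;         // [num_nodes]

    ToyMemory(std::size_t num_nodes, std::size_t d)
        : dim(d), memory(num_nodes, std::vector<float>(d, 0.0f)), last_update(num_nodes, 0.0) {}

    // Mirrors reset_state(): zero all memory and all last-update timestamps.
    void reset_state() {
        for (auto& row : memory) std::fill(row.begin(), row.end(), 0.0f);
        std::fill(last_update.begin(), last_update.end(), 0.0);
    }

    // Simplified update_state(): both endpoints of event i receive msg[i];
    // each node keeps only its latest message, then memory <- tanh(memory + msg).
    void update_state(const std::vector<std::int64_t>& src,
                      const std::vector<std::int64_t>& dst,
                      const std::vector<double>& time,
                      const std::vector<std::vector<float>>& msg) {
        struct Latest { double t; std::size_t event; };
        std::unordered_map<std::int64_t, Latest> latest;
        auto note = [&](std::int64_t node, std::size_t i) {
            auto it = latest.find(node);
            if (it == latest.end() || time[i] >= it->second.t) latest[node] = {time[i], i};
        };
        for (std::size_t i = 0; i < src.size(); ++i) {
            note(src[i], i);
            note(dst[i], i);
        }
        for (const auto& [node, l] : latest) {
            auto& row = memory[static_cast<std::size_t>(node)];
            for (std::size_t k = 0; k < dim; ++k) row[k] = std::tanh(row[k] + msg[l.event][k]);
            last_update[static_cast<std::size_t>(node)] = l.t;
        }
    }
};
```

Because each touched node is updated only from its own previous row and its latest message, the result does not depend on the order in which nodes are visited; nodes absent from the batch keep their memory and timestamp.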
function ~TGNImpl
tgn::TGNImpl::~TGNImpl ()
The documentation for this class was generated from the following file: include/tgn.h