// File: tguf.h
#pragma once
#include <cstddef>
#include <memory>
#include <optional>
#include <stdexcept>
#include <string>

#include <torch/types.h>
namespace tguf {
struct Batch {
torch::Tensor src;
torch::Tensor dst;
torch::Tensor time;
torch::Tensor msg;
std::optional<torch::Tensor>
neg_dst;
};
/// A label event, as returned by TGStore::get_label_event().
struct LabelEvent {
  /// NOTE(review): presumably the ids of the labelled nodes — confirm.
  torch::Tensor n_id;

  /// NOTE(review): presumably the label targets for n_id — confirm.
  torch::Tensor target;
};
/// Description of the dataset a TGUFBuilder is constructed for.
///
/// Every numeric field defaults to 0 and both split markers default to
/// std::nullopt. A default-constructed schema that is filled in field by field
/// therefore never holds indeterminate values. The type remains an aggregate,
/// so positional brace-initialization keeps working with the same field order.
struct TGUFSchema {
  std::string path;  // presumably the output file path — confirm

  // Capacities. NOTE(review): presumably upper bounds on the number of rows
  // that will be appended for each table — confirm.
  std::size_t edge_capacity{0};
  std::size_t label_capacity{0};
  std::size_t node_feat_capacity{0};

  // Per-row feature widths. NOTE(review): presumably the second dimension of
  // the msg / label target / node feature tensors — confirm.
  std::size_t msg_dim{0};
  std::size_t label_dim{0};
  std::size_t node_feat_dim{0};

  // Pre-computed negatives. NOTE(review): semantics inferred from names only —
  // confirm against TGUFBuilder's implementation.
  std::size_t negatives_start_e_id{0};
  std::size_t negatives_per_edge{0};

  // Optional split boundaries; unset by default.
  std::optional<std::size_t> val_start = std::nullopt;
  std::optional<std::size_t> test_start = std::nullopt;
};
/// Writes a TGUF dataset described by a TGUFSchema.
///
/// Uses the pointer-to-implementation idiom: all state lives behind `impl_`.
/// `Impl` is only forward-declared here, so the destructor is defined out of
/// line, where `Impl` is complete.
///
/// NOTE(review): the append_* methods are declared const, yet presumably write
/// through impl_. const-ness of the handle does not propagate through
/// std::unique_ptr, so calling them on a const builder compiles.
class TGUFBuilder {
 public:
  /// Creates a builder for the dataset described by `schema`.
  explicit TGUFBuilder(const TGUFSchema& schema);

  ~TGUFBuilder();

  // Not copyable. The std::unique_ptr<Impl> member already makes the implicit
  // copy operations deleted; they are spelled out so the intent is visible.
  // No move operations exist either: the user-declared destructor suppresses
  // their implicit declaration.
  TGUFBuilder(const TGUFBuilder&) = delete;
  auto operator=(const TGUFBuilder&) -> TGUFBuilder& = delete;

  /// Appends a group of edge events.
  auto append_edges(const Batch& batch) const -> void;

  /// Appends label events given as parallel node-id / time / target tensors.
  auto append_labels(const torch::Tensor& n_id, const torch::Tensor& time,
                     const torch::Tensor& target) const -> void;

  /// Appends node features for the node ids in `n_id`.
  auto append_node_feats(const torch::Tensor& n_id,
                         const torch::Tensor& node_feat) const -> void;

  /// Completes the dataset. NOTE(review): presumably flushes/closes the
  /// output — confirm in the implementation.
  auto finalize() -> void;

 private:
  struct Impl;
  std::unique_ptr<Impl> impl_;
};
/// Abstract access to a temporal graph dataset: edges plus optional node
/// features and optional label events, with train/val/test splits.
///
/// All query methods are const. Concrete stores are obtained from the static
/// factories from_memory() and from_tguf(); the implementing classes are not
/// declared in this header.
///
/// NOTE(review): get_batch() and get_label_event() are virtual functions with
/// default arguments. Defaults are taken from the static type at the call
/// site, not the dynamic type. Overrides must therefore not declare different
/// defaults, or calls through a TGStore& will silently use these ones.
class TGStore {
 public:
  /// Negative-sampling mode requested from get_batch().
  /// NOTE(review): semantics are defined by the implementation; presumably
  /// None leaves Batch::neg_dst unset — confirm.
  enum class NegStrategy {
    None,
    Random,
    PreComputed,
  };

  /// A range of indices with size() == end - start (half-open convention).
  ///
  /// The two-argument constructor enforces end >= start.
  struct IndexRange {
    IndexRange() = default;

    /// @throws std::out_of_range if `e < s`.
    IndexRange(std::size_t s, std::size_t e) : start_(s), end_(e) {
      if (end_ < start_) {
        throw std::out_of_range("Invalid range");
      }
    }

    [[nodiscard]] auto start() const noexcept -> std::size_t { return start_; }
    [[nodiscard]] auto end() const noexcept -> std::size_t { return end_; }
    [[nodiscard]] auto size() const noexcept -> std::size_t {
      return end_ - start_;
    }

    // NOTE(review): public for source compatibility. Assigning these directly
    // bypasses the constructor's end >= start check, after which size() wraps
    // around (unsigned subtraction).
    std::size_t start_{0};
    std::size_t end_{0};
  };

  virtual ~TGStore() = default;

  /// Builds a store from tensors already held in memory.
  /// Node features and the three label tensors are optional.
  /// NOTE(review): val_start/test_start presumably mark the first edge index
  /// of the validation/test splits — confirm.
  [[nodiscard]] static auto from_memory(
      const Batch& edges,
      const std::optional<torch::Tensor>& node_feats = std::nullopt,
      const std::optional<torch::Tensor>& label_n_id = std::nullopt,
      const std::optional<torch::Tensor>& label_time = std::nullopt,
      const std::optional<torch::Tensor>& label_target = std::nullopt,
      std::optional<std::size_t> val_start = std::nullopt,
      std::optional<std::size_t> test_start = std::nullopt)
      -> std::unique_ptr<TGStore>;

  /// Opens a store backed by the TGUF file at `path`.
  /// NOTE(review): presumably a file produced by TGUFBuilder — confirm.
  [[nodiscard]] static auto from_tguf(
      const std::string& path,
      std::optional<std::size_t> val_start = std::nullopt,
      std::optional<std::size_t> test_start = std::nullopt)
      -> std::unique_ptr<TGStore>;

  // Sizes and feature widths of the stored dataset.
  [[nodiscard]] virtual auto edge_count() const -> std::size_t = 0;
  [[nodiscard]] virtual auto node_count() const -> std::size_t = 0;
  [[nodiscard]] virtual auto label_count() const -> std::size_t = 0;
  [[nodiscard]] virtual auto msg_dim() const -> std::size_t = 0;
  [[nodiscard]] virtual auto label_dim() const -> std::size_t = 0;
  [[nodiscard]] virtual auto node_feat_dim() const -> std::size_t = 0;

  // Split ranges over edge indices and over label-event indices.
  [[nodiscard]] virtual auto train_split() const -> IndexRange = 0;
  [[nodiscard]] virtual auto val_split() const -> IndexRange = 0;
  [[nodiscard]] virtual auto test_split() const -> IndexRange = 0;
  [[nodiscard]] virtual auto train_label_split() const -> IndexRange = 0;
  [[nodiscard]] virtual auto val_label_split() const -> IndexRange = 0;
  [[nodiscard]] virtual auto test_label_split() const -> IndexRange = 0;

  /// Returns up to `size` edges starting at edge index `start`, with negatives
  /// per `strategy`, placed on `device`.
  /// NOTE(review): behavior at the end of the edge table is
  /// implementation-defined — confirm.
  [[nodiscard]] virtual auto get_batch(std::size_t start, std::size_t size,
                                       NegStrategy strategy = NegStrategy::None,
                                       torch::Device device = torch::kCPU) const
      -> Batch = 0;

  // Index-based lookups: edge ids (e_id) and node ids (n_id).
  [[nodiscard]] virtual auto gather_timestamps(const torch::Tensor& e_id) const
      -> torch::Tensor = 0;
  [[nodiscard]] virtual auto gather_msgs(const torch::Tensor& e_id) const
      -> torch::Tensor = 0;
  [[nodiscard]] virtual auto gather_node_feats(const torch::Tensor& n_id) const
      -> torch::Tensor = 0;

  /// NOTE(review): presumably the edge index up to which edges precede label
  /// event `l_id` in time — confirm in the implementation.
  [[nodiscard]] virtual auto get_edge_cutoff_for_label_event(
      std::size_t l_id) const -> std::size_t = 0;

  /// Returns label event `l_id`, placed on `device`.
  [[nodiscard]] virtual auto get_label_event(
      std::size_t l_id, torch::Device device = torch::kCPU) const
      -> LabelEvent = 0;
};
} // namespace tguf