gnn_tracking_hpo.load
Everything related to data loaders and data sets.
Module Contents
Functions
| get_graphs_split | Load graphs for training, testing, and validation from the same input directories. |
| get_graphs_separate | Load graphs for training and validation from separate directories. |
| get_loaders | Get data loaders. |
- gnn_tracking_hpo.load.get_graphs_split(*, train_size: int, val_size: int, input_dirs: list[os.PathLike] | list[str], sector: int | None = None, test=False) → dict[str, list]
Load graphs for training, testing, and validation from the same input directories.
- Parameters:
train_size – Number of graphs to use for training
val_size – Number of graphs to use for validation
input_dirs – Directories containing the graphs
sector – Only load graphs from this sector
test –
- Returns:
Training and validation graphs as a dictionary
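The idea of drawing both splits from one pool can be pictured with the following minimal sketch. It is not the library's implementation: it assumes the graphs are already loaded into a list, assigns the first `train_size` graphs to training and the next `val_size` to validation, and uses the hypothetical dictionary keys `"train"` and `"val"`. The helper name `split_graphs` is also hypothetical.

```python
from typing import Any


def split_graphs(graphs: list[Any], train_size: int, val_size: int) -> dict[str, list[Any]]:
    """Hypothetical helper: split one pool of graphs into training and validation sets."""
    needed = train_size + val_size
    if len(graphs) < needed:
        raise ValueError(f"Need {needed} graphs, but only {len(graphs)} are available")
    # The first train_size graphs go to training and the next val_size to validation,
    # so the two sets never overlap. The "train"/"val" keys are an assumption.
    return {
        "train": graphs[:train_size],
        "val": graphs[train_size:needed],
    }


# Strings stand in for loaded graph objects here.
dct = split_graphs([f"graph_{i}" for i in range(10)], train_size=6, val_size=2)
print(len(dct["train"]), len(dct["val"]))  # 6 2
```

Slicing a single list keeps the two splits disjoint by construction, which is the main property that matters when training and validation graphs share a source.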
- gnn_tracking_hpo.load.get_graphs_separate(*, train_size: int, val_size: int, train_dirs: list[str], val_dirs: list[str], sector: int | None = None, test=False) → dict[str, list]
Load graphs for training and validation from separate directories.
- Parameters:
train_size – Number of graphs to use for training
val_size – Number of graphs to use for validation
train_dirs – Directories containing the training graphs
val_dirs – Directories containing the validation graphs
sector – Only load graphs from this sector
test –
- Returns:
Training and validation graphs as a dictionary
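When training and validation graphs live in separate directories, each split is filled from its own directory list. The sketch below illustrates that selection step only, under stated assumptions: graphs are stored one per file with a `.pt` extension, files are taken in sorted order within each directory, and the result uses the hypothetical keys `"train"` and `"val"`. The helpers `collect_graph_files` and `split_from_separate_dirs` are hypothetical, and the sketch stops at file paths rather than loading graph objects.

```python
from pathlib import Path


def collect_graph_files(dirs: list[str], n: int, pattern: str = "*.pt") -> list[Path]:
    """Hypothetical helper: return the first n graph files found in dirs."""
    files: list[Path] = []
    for d in dirs:
        # Assumption: one graph per file, matched by `pattern`, sorted per directory.
        files.extend(sorted(Path(d).glob(pattern)))
    if len(files) < n:
        raise ValueError(f"Requested {n} graphs, found only {len(files)} in {dirs}")
    return files[:n]


def split_from_separate_dirs(
    train_dirs: list[str], val_dirs: list[str], train_size: int, val_size: int
) -> dict[str, list[Path]]:
    """Hypothetical helper: each split is drawn from its own list of directories."""
    return {
        "train": collect_graph_files(train_dirs, train_size),
        "val": collect_graph_files(val_dirs, val_size),
    }
```

Because the two directory lists are independent, the validation set cannot leak into training unless the same directory is passed to both arguments.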
- gnn_tracking_hpo.load.get_loaders(graph_dct: dict[str, list], batch_size=1, val_batch_size=1, test=False) → dict[str, torch_geometric.loader.DataLoader]
Get data loaders for the graphs in graph_dct.
- Parameters:
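graph_dct – Dictionary of graph lists, e.g. as returned by get_graphs_split() or get_graphs_separate()
batch_size – Batch size for the training loader
val_batch_size – Batch size for the validation loader
test –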
- Returns:
Dictionary of data loaders
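The shape of the result, one loader per split with its own batch size, can be mimicked in plain Python. This is only an illustrative stand-in: the real function returns `torch_geometric.loader.DataLoader` objects, which collate graphs into batched graph objects, whereas this sketch just groups items into lists. It assumes that the `"val"` split uses `val_batch_size` and every other split uses `batch_size`; the helpers `iter_batches` and `make_loaders` are hypothetical.

```python
from collections.abc import Iterator
from typing import Any


def iter_batches(items: list[Any], batch_size: int) -> Iterator[list[Any]]:
    """Yield consecutive batches; the last batch may be smaller than batch_size."""
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    for start in range(0, len(items), batch_size):
        yield items[start : start + batch_size]


def make_loaders(
    graph_dct: dict[str, list[Any]], batch_size: int = 1, val_batch_size: int = 1
) -> dict[str, list[list[Any]]]:
    """Hypothetical stand-in for get_loaders: one batched 'loader' per split."""
    loaders = {}
    for key, graphs in graph_dct.items():
        # Assumption: only the "val" split uses val_batch_size.
        bs = val_batch_size if key == "val" else batch_size
        loaders[key] = list(iter_batches(graphs, bs))
    return loaders


print(make_loaders({"train": [0, 1, 2, 3, 4], "val": [0, 1, 2, 3]}, batch_size=2, val_batch_size=4))
# {'train': [[0, 1], [2, 3], [4]], 'val': [[0, 1, 2, 3]]}
```

Keeping a separate validation batch size is useful because evaluation needs no gradients, so larger batches often fit in memory there.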