gnn_tracking_hpo.load

Everything related to data loaders and data sets.

Module Contents

Functions

get_graphs_split(→ dict[str, list])

Load graphs for training, testing, and validation from one directory.

get_graphs_separate(→ dict[str, list])

Load graphs for training and validation from separate directories.

get_loaders(→ dict[str, torch_geometric.loader.DataLoader])

Get data loaders

gnn_tracking_hpo.load.get_graphs_split(*, train_size: int, val_size: int, input_dirs: list[os.PathLike] | list[str], sector: int | None = None, test=False) → dict[str, list]

Load graphs for training, testing, and validation from one directory.

Parameters:
  • train_size – Number of graphs to use for training

  • val_size – Number of graphs to use for validation

  • input_dirs – Directories containing the graphs

  • sector – Only load graphs from the specified sector (default: None)

  • test

Returns:

Training and validation graphs as dictionary
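The split described above can be pictured with plain Python lists. This is a minimal sketch, not the package's implementation: the helper name `split_graphs`, the dictionary keys `"train"` and `"val"`, and the choice to take the first `train_size` items for training are all assumptions made for illustration.

```python
from __future__ import annotations


def split_graphs(graphs: list, *, train_size: int, val_size: int) -> dict[str, list]:
    """Hypothetical helper: carve disjoint training and validation subsets out of one list."""
    if train_size + val_size > len(graphs):
        raise ValueError("not enough graphs for the requested split")
    return {
        "train": graphs[:train_size],
        "val": graphs[train_size : train_size + val_size],
    }


# Stand-in "graphs"; real entries would be graph objects loaded from input_dirs.
graphs = [f"graph_{i}" for i in range(10)]
split = split_graphs(graphs, train_size=6, val_size=2)
```

Like `get_graphs_split`, the hypothetical helper takes its sizes as keyword-only arguments, so the two counts cannot be swapped by accident at the call site.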

gnn_tracking_hpo.load.get_graphs_separate(*, train_size: int, val_size: int, train_dirs: list[str], val_dirs: list[str], sector: int | None = None, test=False) → dict[str, list]

Load graphs for training and validation from separate directories.

Parameters:
  • train_size – Number of graphs to use for training

  • val_size – Number of graphs to use for validation

  • train_dirs – Directories containing the training graphs

  • val_dirs – Directories containing the validation graphs

  • sector – Only load graphs from the specified sector (default: None)

  • test

Returns:

Training and validation graphs as dictionary
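When training and validation graphs live in separate directories, each subset can be selected independently. The sketch below illustrates that idea with the standard library only. The helper `first_n_files`, the `.pt` file extension, the sorted ordering, and the dictionary keys are assumptions for illustration, not the behavior of `get_graphs_separate`.

```python
from __future__ import annotations

import tempfile
from pathlib import Path


def first_n_files(dirs: list[str], n: int) -> list[Path]:
    """Hypothetical helper: collect files from several directories, sorted, capped at n."""
    files = sorted(p for d in dirs for p in Path(d).iterdir() if p.is_file())
    return files[:n]


# Build two throwaway directories with placeholder files to select from.
with tempfile.TemporaryDirectory() as train_dir, tempfile.TemporaryDirectory() as val_dir:
    for i in range(5):
        (Path(train_dir) / f"train_{i}.pt").touch()
    for i in range(3):
        (Path(val_dir) / f"val_{i}.pt").touch()

    selection = {
        "train": first_n_files([train_dir], 4),  # plays the role of train_size
        "val": first_n_files([val_dir], 2),  # plays the role of val_size
    }
    selected_names = {key: [p.name for p in paths] for key, paths in selection.items()}
```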

gnn_tracking_hpo.load.get_loaders(graph_dct: dict[str, list], batch_size=1, val_batch_size=1, test=False) → dict[str, torch_geometric.loader.DataLoader]

Get data loaders

Parameters:
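  • graph_dct – Dictionary mapping names to lists of graphs

  • batch_size – Batch size (default: 1)

  • val_batch_size – Batch size for validation (default: 1)

  • test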

Returns:

Dictionary of data loaders
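The effect of the two batch-size parameters can be pictured without PyTorch Geometric. In the sketch below, plain lists of batches stand in for `torch_geometric.loader.DataLoader` objects; the helper `batched` and the keys `"train"` and `"val"` are assumptions for illustration, and a real loader also merges the graphs of each batch into one batch object.

```python
from __future__ import annotations


def batched(items: list, batch_size: int) -> list[list]:
    """Group items into consecutive batches; the last batch may be shorter."""
    return [items[i : i + batch_size] for i in range(0, len(items), batch_size)]


# Integers stand in for graphs here.
graph_dct = {"train": list(range(5)), "val": list(range(3))}

# Mirrors batch_size (training) and val_batch_size (validation).
batch_sizes = {"train": 2, "val": 1}

loaders = {key: batched(graphs, batch_sizes[key]) for key, graphs in graph_dct.items()}
```

Each key of the input dictionary maps to its own sequence of batches, which matches the documented return shape of one loader per key.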