load_dataset
- class LoadDataset(Dataset)
Load graphs and their labels from a binary DGL graph file.
Example:
    import os
    from agat.data import LoadDataset
    dataset = LoadDataset(os.path.join('dataset', 'all_graphs.bin'))
    # You can index or slice the dataset.
    g0, props0 = dataset[0]
    # g_batch is a batched collection of graphs. See https://docs.dgl.ai/en/1.1.x/generated/dgl.batch.html
    g_batch, props = dataset[0:100]
- __init__(self, dataset_path)
- Parameters:
dataset_path (str) – Path to the binary DGL graph file.
- Returns:
a graph dataset.
- Return type:
list
- __getitem__(self, index)
Index or slice the dataset.
- Parameters:
index (int or slice) – an integer index or a slice
- Returns:
graph or graph batch
- Return type:
DGLGraph
- Returns:
props: the graph labels
- Return type:
dict of torch.Tensor
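The indexing behaviour described above can be illustrated with a minimal sketch. It assumes the dataset holds a list of graphs plus a dict of per-graph label sequences (the shape `dgl.load_graphs` returns) and applies one `index` to both, so an integer yields a single sample and a slice yields aligned slices. `ToyGraphDataset` and the label name `energy` are hypothetical, and plain strings and lists stand in for DGL graphs and `torch.Tensor` labels so the sketch runs without either library.

```python
from typing import Any, Dict, List, Sequence, Tuple, Union


class ToyGraphDataset:
    """Hypothetical stand-in for LoadDataset: a list of graphs plus per-graph labels."""

    def __init__(self, graphs: List[Any], props: Dict[str, Sequence[Any]]):
        # Every label sequence needs exactly one entry per graph.
        for key, values in props.items():
            if len(values) != len(graphs):
                raise ValueError(f"label {key!r} has {len(values)} entries, expected {len(graphs)}")
        self.graphs = graphs
        self.props = props

    def __getitem__(self, index: Union[int, slice]) -> Tuple[Any, Dict[str, Any]]:
        # The same int or slice selects the graph(s) and the matching labels.
        return self.graphs[index], {key: values[index] for key, values in self.props.items()}

    def __len__(self) -> int:
        return len(self.graphs)


# Strings stand in for DGL graphs; "energy" is a hypothetical label name.
dataset = ToyGraphDataset(["g0", "g1", "g2"], {"energy": [-1.0, -2.0, -3.0]})
g0, props0 = dataset[0]         # 'g0', {'energy': -1.0}
batch, props = dataset[0:2]     # ['g0', 'g1'], {'energy': [-1.0, -2.0]}
```

The sketch returns a plain list for a slice; according to the example at the top of this page, the real class returns a batch of graphs in that case.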
- __len__(self)
Get the length of the dataset.
- Returns:
the length of the dataset
- Return type:
int
- class Collater(object)
A collate function for torch.utils.data.DataLoader: https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader
The collate function determines how the samples in a batch are merged.
Example:
    import os
    from agat.data import LoadDataset, Collater
    from torch.utils.data import DataLoader
    dataset = LoadDataset(os.path.join('dataset', 'all_graphs.bin'))
    collate_fn = Collater(device='cuda')
    data_loader = DataLoader(dataset, batch_size=64, shuffle=True, collate_fn=collate_fn)
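To show where `collate_fn` fits: with automatic batching, `torch.utils.data.DataLoader` draws `batch_size` indices (shuffled when `shuffle=True`), fetches each sample with `dataset[i]`, and passes the resulting list of samples to `collate_fn`. The sketch below imitates that loop in plain Python so it runs without PyTorch; `iterate_batches` is a hypothetical helper, not part of AGAT or PyTorch. With `drop_last=False` (the DataLoader default), N samples give ceil(N / batch_size) batches, and the last one may be smaller.

```python
import random
from typing import Any, Callable, Iterator, List, Sequence


def iterate_batches(dataset: Sequence[Any], batch_size: int, shuffle: bool,
                    collate_fn: Callable[[List[Any]], Any], seed: int = 0) -> Iterator[Any]:
    """Hypothetical plain-Python imitation of DataLoader's batching loop."""
    indices = list(range(len(dataset)))
    if shuffle:
        # DataLoader reshuffles every epoch; a fixed seed keeps this sketch repeatable.
        random.Random(seed).shuffle(indices)
    for start in range(0, len(indices), batch_size):
        # Fetch one sample per index, then let collate_fn merge the list.
        samples = [dataset[i] for i in indices[start:start + batch_size]]
        yield collate_fn(samples)


# 150 toy (graph, props) samples; the identity collate_fn returns each batch as a list.
toy_dataset = [(f"g{i}", {"energy": float(i)}) for i in range(150)]
batches = list(iterate_batches(toy_dataset, batch_size=64, shuffle=True, collate_fn=lambda s: s))
print(len(batches), [len(b) for b in batches])  # 3 [64, 64, 22]
```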
- __init__(self, device='cuda')
- Parameters:
device (str, optional) – device to store the merged data, defaults to ‘cuda’
- __call__(self, data)
Collate the data into batches.
- Parameters:
data (list) – the samples DataLoader collects for one batch; each sample is a (graph, props) tuple returned by LoadDataset
- Returns:
a batched DGL graph. See https://docs.dgl.ai/en/1.1.x/generated/dgl.batch.html
- Return type:
DGLGraph
- Returns:
Graph labels
- Return type:
dict of torch.Tensor
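A minimal sketch of the merge that `__call__` performs, assuming each sample is a `(graph, props)` pair as returned by `LoadDataset`: gather the graphs into one batch and, for every label key, collect the per-sample values in order. With DGL and PyTorch the natural calls would be `dgl.batch` for the graphs and `torch.stack` for the label tensors, followed by `.to(device)`; whether AGAT uses exactly these is an assumption. `ToyCollater` is hypothetical, and plain lists replace graphs and tensors so the sketch runs without either library.

```python
from typing import Any, Dict, List, Tuple


class ToyCollater:
    """Hypothetical stand-in for Collater that merges (graph, props) samples."""

    def __init__(self, device: str = "cpu"):
        # Mirrors the real signature only; plain lists have no device to move to.
        self.device = device

    def __call__(self, data: List[Tuple[Any, Dict[str, Any]]]) -> Tuple[List[Any], Dict[str, List[Any]]]:
        # With DGL this step would be dgl.batch(graphs).to(self.device).
        graphs = [graph for graph, _ in data]
        # One entry per sample for each label key, in sample order
        # (with PyTorch, torch.stack over these values would give one tensor per key).
        keys = data[0][1].keys()
        props = {k: [sample_props[k] for _, sample_props in data] for k in keys}
        return graphs, props


samples = [("g0", {"energy": -1.0}), ("g1", {"energy": -2.5})]
batch_graphs, batch_props = ToyCollater()(samples)
print(batch_graphs, batch_props)  # ['g0', 'g1'] {'energy': [-1.0, -2.5]}
```

The sketch defaults `device` to "cpu" only because nothing is moved to a GPU here; the documented default of the real class is 'cuda'.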