load_dataset

class LoadDataset(Dataset)

Load graphs stored in a binary DGL graph file.

Example:

import os
from agat.data import LoadDataset
dataset = LoadDataset(os.path.join('dataset', 'all_graphs.bin'))

# You can index or slice the dataset.
g0, props0 = dataset[0]
# Slicing returns g_batch, a batch of graphs. See https://docs.dgl.ai/en/1.1.x/generated/dgl.batch.html
g_batch, props = dataset[0:100]

__init__(self, dataset_path)
Parameters:

dataset_path (str) – Path to the binary DGL graph file.

Returns:

a graph dataset.

Return type:

list

__getitem__(self, index)

Index or slice the dataset.

Parameters:

index (int or slice) – an integer index or a slice

Returns:

a single graph (integer index) or a batched graph (slice)

Return type:

DGLGraph

Returns:

props – graph labels

Return type:

A dict of torch.tensor

__len__(self)

Return the number of graphs in the dataset.

Returns:

the number of graphs in the dataset

Return type:

int
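The int-versus-slice behaviour of `__getitem__`, together with `__len__`, can be illustrated with a pure-Python stand-in. This is a sketch under stated assumptions, not AGAT's implementation: the class name `ToyGraphDataset` is hypothetical, plain strings stand in for DGL graphs, Python lists stand in for `torch.tensor` label columns, and the "batch" returned for a slice is a plain list rather than the result of `dgl.batch`.

```python
from typing import Any, Dict, List, Tuple, Union


class ToyGraphDataset:
    """Hypothetical stand-in for LoadDataset (needs neither DGL nor torch)."""

    def __init__(self, graphs: List[Any], props: Dict[str, List[float]]):
        # One label list per property name; each list holds one entry per graph.
        for name, values in props.items():
            if len(values) != len(graphs):
                raise ValueError(
                    f"property {name!r} has {len(values)} entries, expected {len(graphs)}"
                )
        self.graphs = graphs
        self.props = props

    def __getitem__(self, index: Union[int, slice]) -> Tuple[Any, Dict[str, Any]]:
        # Select the matching label(s) of every property with the same index.
        labels = {k: v[index] for k, v in self.props.items()}
        if isinstance(index, slice):
            # AGAT merges the selected graphs into one batch; a plain list stands in here.
            return list(self.graphs[index]), labels
        return self.graphs[index], labels

    def __len__(self) -> int:
        # The length is the number of stored graphs.
        return len(self.graphs)


graphs = ["g0", "g1", "g2", "g3"]
props = {"energy": [-1.0, -2.0, -3.0, -4.0], "stress": [0.1, 0.2, 0.3, 0.4]}
dataset = ToyGraphDataset(graphs, props)

g0, props0 = dataset[0]
g_batch, props_batch = dataset[0:2]
print(g0, props0)            # g0 {'energy': -1.0, 'stress': 0.1}
print(g_batch, props_batch)  # ['g0', 'g1'] {'energy': [-1.0, -2.0], 'stress': [0.1, 0.2]}
```

The same index object is applied to the graph list and to every label column, which keeps graphs and labels aligned whether a single item or a range is requested.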

class Collater(object)

The collate function used in torch.utils.data.DataLoader: https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader

The collate function determines how the samples drawn for one batch are merged.
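It can help to see what `DataLoader` hands to a collate function. The sketch below is a simplified, hypothetical model of the default map-style behaviour (single process, no custom sampler, `drop_last=False`): indices are grouped into chunks of `batch_size`, `dataset[i]` is fetched for each index, and the resulting list of samples is passed to `collate_fn`. The function name `iterate_batches` is an illustration only, not part of PyTorch or AGAT.

```python
import random
from typing import Any, Callable, Iterator, List, Sequence


def iterate_batches(
    dataset: Sequence[Any],
    batch_size: int,
    collate_fn: Callable[[List[Any]], Any],
    shuffle: bool = False,
    seed: int = 0,
) -> Iterator[Any]:
    """Hypothetical, simplified model of how DataLoader calls collate_fn."""
    indices = list(range(len(dataset)))
    if shuffle:
        # Reproducible shuffle of the sample order.
        random.Random(seed).shuffle(indices)
    for start in range(0, len(indices), batch_size):
        # Gather one sample per index, then hand the whole list to collate_fn.
        samples = [dataset[i] for i in indices[start:start + batch_size]]
        yield collate_fn(samples)


# Each sample mimics the (graph, props) pair returned by dataset[i].
samples = [(f"g{i}", {"energy": float(-i)}) for i in range(5)]
batches = list(iterate_batches(samples, batch_size=2, collate_fn=lambda batch: batch))
print([len(b) for b in batches])  # [2, 2, 1]
```

With five samples and `batch_size=2`, the last batch holds the single leftover sample, because partial batches are kept when `drop_last` is off.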

Example:

import os
from agat.data import LoadDataset, Collater
from torch.utils.data import DataLoader

dataset = LoadDataset(os.path.join('dataset', 'all_graphs.bin'))
collate_fn = Collater(device='cuda')
data_loader = DataLoader(dataset, batch_size=64, shuffle=True, collate_fn=collate_fn)

__init__(self, device='cuda')
Parameters:

device (str, optional) – device on which the merged data are stored; defaults to ‘cuda’

__call__(self, data)

Collate the data into batches.

Parameters:

data (list of tuple) – a batch of samples, each a (graph, props) tuple as returned by indexing LoadDataset

Returns:

a batched DGL graph. See https://docs.dgl.ai/en/1.1.x/generated/dgl.batch.html

Return type:

DGLGraph

Returns:

Graph labels

Return type:

A dict of torch.tensor
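The merging step performed by `__call__` can be sketched in plain Python. This is a hypothetical analogue under stated assumptions, not AGAT's code: `collate_samples` is an illustrative name, strings stand in for DGL graphs, and label lists stand in for `torch.tensor` values. According to the documentation above, the real `Collater` returns a batched `DGLGraph` (see `dgl.batch`) and a dict of tensors stored on `device`; that conversion is omitted here.

```python
from typing import Any, Dict, List, Tuple


def collate_samples(
    data: List[Tuple[Any, Dict[str, float]]],
) -> Tuple[List[Any], Dict[str, List[float]]]:
    """Hypothetical, pure-Python analogue of Collater.__call__."""
    if not data:
        raise ValueError("cannot collate an empty batch")
    # Keep the graphs in batch order; AGAT would merge them into one batched graph.
    graphs = [g for g, _ in data]
    # Merge labels key by key, preserving the sample order of the batch.
    keys = data[0][1].keys()
    props = {k: [p[k] for _, p in data] for k in keys}
    # AGAT would also convert labels to tensors and move everything to `device`.
    return graphs, props


batch = [("g0", {"energy": -1.0, "stress": 0.1}), ("g1", {"energy": -2.0, "stress": 0.2})]
graphs, props = collate_samples(batch)
print(graphs)  # ['g0', 'g1']
print(props)   # {'energy': [-1.0, -2.0], 'stress': [0.1, 0.2]}
```

Collating label by label (rather than keeping a list of per-sample dicts) is what lets each property be handled as one array per batch, matching the "dict of torch.tensor" return type.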