Small Scale Link Prediction (FB15K-237)
In this tutorial, we use the FB15K_237 knowledge graph to walk step by step through training a link prediction model with the DistMult algorithm: preprocessing the dataset, defining the configuration file, and running training.
1. Preprocess Dataset
Preprocessing a dataset is straightforward with the marius_preprocess command, which is installed together with marius. See (TODO link) for installation information.
Assuming marius_preprocess has been built, we preprocess the FB15K_237 dataset by running the following command (assuming we are in the marius root directory):
$ marius_preprocess --dataset fb15k_237 --output_directory datasets/fb15k_237_example/
Downloading FB15K-237.2.zip to datasets/fb15k_237_example/FB15K-237.2.zip
Reading edges
Remapping Edges
Node mapping written to: datasets/fb15k_237_example/nodes/node_mapping.txt
Relation mapping written to: datasets/fb15k_237_example/edges/relation_mapping.txt
Dataset statistics written to: datasets/fb15k_237_example/dataset.yaml
The --dataset flag specifies which of the pre-set datasets marius_preprocess will download and preprocess.
The --output_directory flag specifies where the preprocessed graph will be written and is set by the user. In this example, we assume the datasets/fb15k_237_example directory does not exist yet; marius_preprocess will create it for us.
For detailed usage of marius_preprocess, run the following command:
$ marius_preprocess -h
Let’s check what is inside the created directory:
$ ls -l datasets/fb15k_237_example/
dataset.yaml # input dataset statistics
nodes/
node_mapping.txt # mapping of raw node ids to integer uuids
edges/
relation_mapping.txt # mapping of raw edge(relation) ids to integer uuids
test_edges.bin # preprocessed testing edge list
train_edges.bin # preprocessed training edge list
validation_edges.bin # preprocessed validation edge list
train.txt # raw training edge list
test.txt # raw testing edge list
valid.txt # raw validation edge list
text_cvsc.txt # relation triples as used in Toutanova and Chen CVSM-2015
text_emnlp.txt # relation triples as used in Toutanova et al. EMNLP-2015
README.txt # README of the downloaded FB15K-237 dataset
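The raw edge lists (train.txt, valid.txt, test.txt) store one triple per line as tab-separated head, relation, and tail identifiers. As a minimal sketch for inspecting such a file, the snippet below counts edges, distinct entities, and distinct relations. The helper name `summarize_triples` and the two sample lines are illustrative assumptions and are not part of Marius.

```python
from typing import Iterable


def summarize_triples(lines: Iterable[str]) -> dict:
    """Count edges, distinct entities, and distinct relations in tab-separated triples."""
    entities, relations, num_edges = set(), set(), 0
    for line in lines:
        line = line.rstrip("\n")
        if not line:
            continue  # skip blank lines
        head, relation, tail = line.split("\t")
        entities.update((head, tail))
        relations.add(relation)
        num_edges += 1
    return {
        "num_edges": num_edges,
        "num_nodes": len(entities),
        "num_relations": len(relations),
    }


# Made-up sample lines in the raw format; for a real file, pass an open
# file handle for train.txt instead of this list.
sample = ["/m/a\t/rel/x\t/m/b\n", "/m/b\t/rel/y\t/m/c\n"]
summary = summarize_triples(sample)
print(summary)
```

Running the same function over the raw training file gives a quick cross-check against the num_train value reported in dataset.yaml.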
Let’s check what is inside the generated dataset.yaml file:
$ cat datasets/fb15k_237_example/dataset.yaml
dataset_dir: /marius-internal/datasets/fb15k_237_example/
num_edges: 272115
num_nodes: 14541
num_relations: 237
num_train: 272115
num_valid: 17535
num_test: 20466
node_feature_dim: -1
rel_feature_dim: -1
num_classes: -1
initialized: false
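Because the statistics file shown above is a flat list of key-value pairs, it can be read without a YAML library. The following is a minimal sketch under that assumption; `parse_flat_yaml` is a hypothetical helper, and the embedded text copies the values printed above. It also shows that num_edges equals num_train here, while the three splits together hold 310116 edges.

```python
def parse_flat_yaml(text: str) -> dict:
    """Parse flat 'key: value' lines, converting integers and booleans."""
    result = {}
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("#"):
            continue  # skip blank lines and full-line comments
        key, _, value = line.partition(":")
        key, value = key.strip(), value.strip()
        if value in ("true", "false"):
            result[key] = value == "true"
        else:
            try:
                result[key] = int(value)
            except ValueError:
                result[key] = value  # keep strings such as paths unchanged
    return result


stats_text = """\
dataset_dir: /marius-internal/datasets/fb15k_237_example/
num_edges: 272115
num_nodes: 14541
num_relations: 237
num_train: 272115
num_valid: 17535
num_test: 20466
node_feature_dim: -1
rel_feature_dim: -1
num_classes: -1
initialized: false
"""

stats = parse_flat_yaml(stats_text)
total_edges = stats["num_train"] + stats["num_valid"] + stats["num_test"]
print(stats["num_nodes"], stats["num_relations"], total_edges)
```

For nested configuration files, such as the training configuration defined later, a full YAML parser is the better choice.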
Note
If the above marius_preprocess command fails due to any missing directory errors, please create the <output_directory>/edges and <output_directory>/nodes directories as a workaround.
2. Define Configuration File
To train a model, we need to define a YAML configuration file based on the information produced by marius_preprocess.
The configuration file contains information including, but not limited to, the inputs to the model, the training procedure, and the hyperparameters to optimize. Given a configuration file, marius assembles a model according to the given parameters. The configuration file is grouped into four sections:
Model: Defines the architecture of the model, neighbor sampling configuration, loss, and optimizer(s).
Storage: Specifies the input dataset and how to store the graph, features, and embeddings.
Training: Sets options for the training procedure and hyperparameters, e.g. batch size and negative sampling.
Evaluation: Sets options for the evaluation procedure (if any). The options here are similar to those in the training section.
For the full configuration schema, please refer to docs/config_interface.
An example YAML configuration file for the FB15K_237 dataset is given in examples/configuration/fb15k_237.yaml. Note that the dataset_dir is set to the preprocessing output directory, which in our example is datasets/fb15k_237_example/.
Let’s create the same YAML configuration file for the FB15K_237 dataset from scratch. We follow the structure of the configuration file and create each of the four sections one by one. In a YAML file, indentation is used to denote nesting and all parameters are in the format of key-value pairs.
First, we define the model. We begin by setting all required parameters. This includes learning_task, encoder, decoder, and loss. The rest of the configurations can be fine-tuned by the user.
model:
  learning_task: LINK_PREDICTION # set the learning task to link prediction
  encoder:
    layers:
      - - type: EMBEDDING # set the encoder to be an embedding table with 50-dimensional embeddings
          output_dim: 50
  decoder:
    type: DISTMULT # set the decoder to DistMult
    options:
      input_dim: 50
  loss:
    type: SOFTMAX_CE
    options:
      reduction: SUM
  dense_optimizer: # optimizer to use for dense model parameters. In this case these are the DistMult relation (edge-type) embeddings
    type: ADAM
    options:
      learning_rate: 0.1
  sparse_optimizer: # optimizer to use for node embedding table
    type: ADAGRAD
    options:
      learning_rate: 0.1
storage:
  # omit
training:
  # omit
evaluation:
  # omit
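To make the decoder and loss choices in the model section more concrete, here is a minimal sketch of the standard DistMult score (the sum of the element-wise product of head, relation, and tail embeddings) and a softmax cross-entropy loss for one positive edge against its negatives. It illustrates the textbook definitions only; how Marius implements and batches these internally is not shown here, and the function names and toy vectors are assumptions made for illustration.

```python
import math
from typing import Sequence


def distmult_score(head: Sequence[float], relation: Sequence[float], tail: Sequence[float]) -> float:
    """DistMult score: sum over dimensions of head * relation * tail."""
    return sum(h * r * t for h, r, t in zip(head, relation, tail))


def softmax_ce(pos_score: float, neg_scores: Sequence[float]) -> float:
    """Cross-entropy where the positive edge is the target class among all candidates."""
    scores = [pos_score, *neg_scores]
    max_score = max(scores)  # subtract the maximum for numerical stability
    log_sum_exp = max_score + math.log(sum(math.exp(s - max_score) for s in scores))
    return log_sum_exp - pos_score


# Toy 3-dimensional embeddings (the configuration above uses 50 dimensions).
head = [1.0, 0.0, 2.0]
relation = [0.5, 1.0, 1.0]
tail = [2.0, 3.0, 1.0]
corrupted_tail = [0.0, 1.0, 0.0]

pos = distmult_score(head, relation, tail)            # 1*0.5*2 + 0*1*3 + 2*1*1 = 3.0
neg = distmult_score(head, relation, corrupted_tail)  # every term is zero
loss = softmax_ce(pos, [neg])  # with SUM reduction, such terms are added over all positives
print(pos, neg, loss)
```

Note that the DistMult score is symmetric in head and tail: swapping them leaves the score unchanged.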
Next, we set the storage and dataset. We begin by setting all required parameters. This includes dataset. Here, the dataset_dir is set to datasets/fb15k_237_example/, which is the preprocessing output directory.
model:
  # omit
storage:
  device_type: cuda
  dataset:
    dataset_dir: datasets/fb15k_237_example/
  edges:
    type: DEVICE_MEMORY
  embeddings:
    type: DEVICE_MEMORY
  save_model: true
training:
  # omit
evaluation:
  # omit
Lastly, we configure training and evaluation. We begin by setting all required parameters. This includes num_epochs and negative_sampling. We set num_epochs=10 (10 epochs to train) to demonstrate this example. Note that negative_sampling is required for link prediction.
model:
  # omit
storage:
  # omit
training:
  batch_size: 1000
  negative_sampling:
    num_chunks: 10
    negatives_per_positive: 500
    degree_fraction: 0.0
    filtered: false
  num_epochs: 10
  pipeline:
    sync: true
  epochs_per_shuffle: 1
evaluation:
  batch_size: 1000
  negative_sampling:
    filtered: true
  pipeline:
    sync: true
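The training options can be read as follows, with the caveat that this interpretation is an assumption rather than a description of Marius internals: each batch of 1000 edges is split into num_chunks groups, each group shares one pool of negatives_per_positive uniformly sampled node ids (degree_fraction: 0.0 is taken to mean no degree-based sampling), and filtered: false means the pool is not checked against true edges. The hypothetical helper below sketches that scheme and also computes the number of batches per epoch for the 272115 training edges.

```python
import math
import random


def sample_chunked_negatives(batch_size, num_chunks, negatives_per_positive, num_nodes, rng):
    """Return (start, end, negative_node_ids) for each chunk of a batch.

    Every positive edge with batch index in [start, end) shares the same negatives.
    """
    if batch_size <= 0:
        return []
    chunk_size = math.ceil(batch_size / num_chunks)
    chunks = []
    for start in range(0, batch_size, chunk_size):
        end = min(start + chunk_size, batch_size)
        # Uniform sampling over node ids; no filtering against true edges.
        negatives = [rng.randrange(num_nodes) for _ in range(negatives_per_positive)]
        chunks.append((start, end, negatives))
    return chunks


rng = random.Random(0)  # fixed seed so the sketch is reproducible
chunks = sample_chunked_negatives(
    batch_size=1000, num_chunks=10, negatives_per_positive=500, num_nodes=14541, rng=rng
)
batches_per_epoch = math.ceil(272115 / 1000)
print(len(chunks), batches_per_epoch)
```

Sharing negatives within a chunk trades some sampling diversity for fewer embedding lookups per batch.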
3. Train Model
After defining our configuration file, training is run with marius_train <your_config.yaml>.
We can now train our example using the configuration file we just created by running the following command (assuming we are in the marius root directory):
$ marius_train datasets/fb15k_237_example/fb15k_237.yaml
[2022-04-03 14:53:15.106] [info] [marius.cpp:45] Start initialization
[04/03/22 14:53:19.140] Initialization Complete: 4.034s
[04/03/22 14:53:19.147] ################ Starting training epoch 1 ################
[04/03/22 14:53:19.224] Edges processed: [28000/272115], 10.29%
[04/03/22 14:53:19.295] Edges processed: [56000/272115], 20.58%
[04/03/22 14:53:19.369] Edges processed: [84000/272115], 30.87%
[04/03/22 14:53:19.447] Edges processed: [112000/272115], 41.16%
[04/03/22 14:53:19.525] Edges processed: [140000/272115], 51.45%
[04/03/22 14:53:19.603] Edges processed: [168000/272115], 61.74%
[04/03/22 14:53:19.685] Edges processed: [196000/272115], 72.03%
[04/03/22 14:53:19.765] Edges processed: [224000/272115], 82.32%
[04/03/22 14:53:19.851] Edges processed: [252000/272115], 92.61%
[04/03/22 14:53:19.906] Edges processed: [272115/272115], 100.00%
[04/03/22 14:53:19.906] ################ Finished training epoch 1 ################
[04/03/22 14:53:19.906] Epoch Runtime: 758ms
[04/03/22 14:53:19.906] Edges per Second: 358990.75
[04/03/22 14:53:19.906] Evaluating validation set
[04/03/22 14:53:19.972]
=================================
Link Prediction: 35070 edges evaluated
Mean Rank: 443.786313
MRR: 0.233709
Hits@1: 0.157998
Hits@3: 0.258597
Hits@5: 0.308640
Hits@10: 0.382407
Hits@50: 0.560137
Hits@100: 0.633191
=================================
[04/03/22 14:53:19.972] Evaluating test set
[04/03/22 14:53:20.043]
=================================
Link Prediction: 40932 edges evaluated
Mean Rank: 454.272940
MRR: 0.230645
Hits@1: 0.155282
Hits@3: 0.253103
Hits@5: 0.304065
Hits@10: 0.382073
Hits@50: 0.559758
Hits@100: 0.630192
=================================
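The metrics in these evaluation blocks are standard ranking metrics. As a minimal sketch, assuming we already have the rank of the true entity for each evaluated prediction (rank 1 is best), they can be computed as shown below; the helper name and the toy ranks are illustrative. Note that the reported counts (35070 and 40932) are exactly twice num_valid and num_test, which is consistent with each edge being ranked twice, for example once per corrupted endpoint, although that interpretation is an assumption.

```python
from typing import Sequence


def ranking_metrics(ranks: Sequence[int], ks: Sequence[int] = (1, 3, 5, 10)) -> dict:
    """Compute mean rank, mean reciprocal rank (MRR), and Hits@k from 1-based ranks."""
    n = len(ranks)
    metrics = {
        "mean_rank": sum(ranks) / n,
        "mrr": sum(1.0 / r for r in ranks) / n,
    }
    for k in ks:
        # Hits@k is the fraction of predictions ranked within the top k.
        metrics[f"hits@{k}"] = sum(r <= k for r in ranks) / n
    return metrics


toy_ranks = [1, 2, 4, 10]
metrics = ranking_metrics(toy_ranks)
print(metrics)
```

A lower mean rank and higher MRR and Hits@k indicate better link prediction quality.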
After running this configuration for 10 epochs, we should see results similar to those below, with an MRR roughly equal to 0.25:
=================================
[04/03/22 14:53:27.861] ################ Starting training epoch 10 ################
[04/03/22 14:53:27.944] Edges processed: [28000/272115], 10.29%
[04/03/22 14:53:28.023] Edges processed: [56000/272115], 20.58%
[04/03/22 14:53:28.115] Edges processed: [84000/272115], 30.87%
[04/03/22 14:53:28.220] Edges processed: [112000/272115], 41.16%
[04/03/22 14:53:28.315] Edges processed: [140000/272115], 51.45%
[04/03/22 14:53:28.410] Edges processed: [168000/272115], 61.74%
[04/03/22 14:53:28.506] Edges processed: [196000/272115], 72.03%
[04/03/22 14:53:28.602] Edges processed: [224000/272115], 82.32%
[04/03/22 14:53:28.699] Edges processed: [252000/272115], 92.61%
[04/03/22 14:53:28.772] Edges processed: [272115/272115], 100.00%
[04/03/22 14:53:28.772] ################ Finished training epoch 10 ################
[04/03/22 14:53:28.772] Epoch Runtime: 911ms
[04/03/22 14:53:28.772] Edges per Second: 298699.22
[04/03/22 14:53:28.772] Evaluating validation set
[04/03/22 14:53:28.834]
=================================
Link Prediction: 35070 edges evaluated
Mean Rank: 303.712946
MRR: 0.259462
Hits@1: 0.173253
Hits@3: 0.286570
Hits@5: 0.348104
Hits@10: 0.434474
Hits@50: 0.626775
Hits@100: 0.706045
=================================
[04/03/22 14:53:28.835] Evaluating test set
[04/03/22 14:53:28.904]
=================================
Link Prediction: 40932 edges evaluated
Mean Rank: 317.841664
MRR: 0.255330
Hits@1: 0.169794
Hits@3: 0.281858
Hits@5: 0.341860
Hits@10: 0.429859
Hits@50: 0.625208
Hits@100: 0.703875
=================================
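To compare epochs without reading the output by eye, the metric lines can be pulled out of the log text. This is a minimal sketch that assumes the log format printed in this tutorial; `parse_eval_blocks` is a hypothetical helper, and the sample text is an abridged copy of the epoch 10 output above.

```python
import re

METRIC_RE = re.compile(r"^(Mean Rank|MRR|Hits@\d+):\s*([0-9.]+)\s*$")


def parse_eval_blocks(log_text: str) -> list:
    """Return one dict of metrics per 'Link Prediction' block, in log order."""
    blocks = []
    for line in log_text.splitlines():
        line = line.strip()
        if line.startswith("Link Prediction:"):
            blocks.append({})  # a new evaluation block starts here
            continue
        match = METRIC_RE.match(line)
        if match and blocks:
            blocks[-1][match.group(1)] = float(match.group(2))
    return blocks


sample_log = """\
[04/03/22 14:53:28.772] Evaluating validation set
=================================
Link Prediction: 35070 edges evaluated
Mean Rank: 303.712946
MRR: 0.259462
Hits@10: 0.434474
=================================
[04/03/22 14:53:28.835] Evaluating test set
=================================
Link Prediction: 40932 edges evaluated
Mean Rank: 317.841664
MRR: 0.255330
Hits@10: 0.429859
=================================
"""

blocks = parse_eval_blocks(sample_log)
valid_metrics, test_metrics = blocks
print(valid_metrics["MRR"], test_metrics["MRR"])
```

Applied to the full output, the blocks alternate between validation and test results, one pair per epoch.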
Let’s check again what was added in the datasets/fb15k_237_example/ directory. For clarity, we only list the files that were created in training. Notice that several files have been created, including the trained model, the embedding table, a full configuration file, and output logs:
$ ls datasets/fb15k_237_example/
model.pt # contains the dense model parameters, embeddings of the edge-types
model_state.pt # optimizer state of the trained model parameters
full_config.yaml # detailed config generated based on user-defined config
metadata.csv # information about metadata
logs/ # logs containing output, error, debug, information
nodes/
embeddings.bin # trained node embeddings of the graph
embeddings_state.bin # node embedding optimizer state
...
edges/
...
...
Note
model.pt contains the dense model parameters. For DistMult, these are the embeddings of the edge types. For GNN encoders, this file will include the GNN parameters.
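For downstream use of the trained node embeddings, the sketch below reads a raw binary file as a table of float32 rows. It rests on an assumption that is not confirmed by this tutorial: that embeddings.bin stores num_nodes × dim little-endian float32 values in row-major order with no header. The helper `read_embeddings` is hypothetical, and the demonstration uses a small synthetic file rather than real Marius output.

```python
import os
import struct
import tempfile


def read_embeddings(path: str, num_nodes: int, dim: int) -> list:
    """Read a headerless little-endian float32 file as num_nodes rows of dim values."""
    expected_bytes = num_nodes * dim * 4  # 4 bytes per float32
    with open(path, "rb") as f:
        data = f.read()
    if len(data) != expected_bytes:
        raise ValueError(f"expected {expected_bytes} bytes, found {len(data)}")
    values = struct.unpack(f"<{num_nodes * dim}f", data)
    return [list(values[i * dim:(i + 1) * dim]) for i in range(num_nodes)]


# Demonstrate on a synthetic 3 x 2 table written to a temporary directory.
with tempfile.TemporaryDirectory() as tmp:
    demo_path = os.path.join(tmp, "embeddings_demo.bin")
    with open(demo_path, "wb") as f:
        f.write(struct.pack("<6f", 0.0, 0.5, 1.0, 1.5, 2.0, 2.5))
    rows = read_embeddings(demo_path, num_nodes=3, dim=2)

print(rows)
```

If the layout assumption holds, the values for this tutorial would be num_nodes=14541 and dim=50, and the size check guards against a mismatch.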