examples/07_tutorials/KDD2020-tutorial/step5_run_lightgcn.ipynb
<i>Copyright (c) Recommenders contributors.</i>
<i>Licensed under the MIT License.</i>
We offer an example to help readers run an ID-based collaborative filtering baseline with LightGCN.
LightGCN is a simple and neat Graph Convolution Network (GCN) model for recommender systems. It uses a GCN to learn the embeddings of users/items, with the goal that low-order and high-order user-item interactions are explicitly exploited in the embedding function.
The model architecture is illustrated as follows:
For more details and instructions, please refer to lightgcn_deep_dive.ipynb.
import os
import pandas as pd
import numpy as np
import tensorflow as tf
from recommenders.utils.timer import Timer
from recommenders.models.deeprec.models.graphrec.lightgcn import LightGCN
from recommenders.models.deeprec.DataModel.ImplicitCF import ImplicitCF
from recommenders.datasets import movielens
from recommenders.datasets.python_splitters import python_stratified_split
from recommenders.evaluation.python_evaluation import map_at_k, ndcg_at_k, precision_at_k, recall_at_k
from recommenders.utils.constants import SEED as DEFAULT_SEED
from recommenders.models.deeprec.deeprec_utils import prepare_hparams
from recommenders.models.deeprec.deeprec_utils import cal_metric
from utils.general import *
from utils.data_helper import *
from utils.task_helper import *
# Silence TensorFlow logging below the ERROR level.
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

# Suffix used in the data file names read further down.
tag = "small"

# Folder holding the raw input data, and the working folder for LightGCN files.
rawdata_dir = "data_folder/my/DKN-training-folder"
lightgcn_dir = "data_folder/my/LightGCN-training-folder"

# Presumably creates the working folder if it is missing -- helper is defined
# outside this notebook (TODO confirm).
create_dir(lightgcn_dir)
First, we need to transform the raw dataset into LightGCN's input data format:
# Write the LightGCN-format files for this tag into lightgcn_dir, based on the
# raw data in rawdata_dir.
prepare_dataset(lightgcn_dir, rawdata_dir, tag)

train_path = os.path.join(lightgcn_dir, f"lightgcn_train_{tag}.txt")

# NOTE(review): header=0 combined with names= makes pandas consume the first
# line of the file as a header row and replace it with these column names.
# Confirm the file really starts with a header line.
df_train = pd.read_csv(
    train_path,
    sep=" ",
    engine="python",
    header=0,
    names=["userID", "itemID", "rating"],
)
df_train.head()
LightGCN only takes positive user-item interactions for model training. Pairs with rating < 1 will be ignored by the model.
valid_path = os.path.join(lightgcn_dir, f"lightgcn_valid_{tag}.txt")

# Same layout as the training file: space-separated user, item, rating.
df_valid = pd.read_csv(
    valid_path,
    sep=" ",
    engine="python",
    header=0,
    names=["userID", "itemID", "rating"],
)

# Wrap both frames in the data object handed to the LightGCN model; the
# validation frame is passed as its `test` split.
data = ImplicitCF(
    train=df_train,
    test=df_valid,
    col_user="userID",
    col_item="itemID",
    col_rating="rating",
    seed=0,
)
yaml_file = "./lightgcn.yaml"

# Values passed alongside the YAML file when building the hyper-parameters.
hparam_overrides = {
    "learning_rate": 0.005,
    "eval_epoch": 1,
    "top_k": 10,
    "save_model": True,
    "epochs": 15,
    "save_epoch": 1,
}
hparams = prepare_hparams(yaml_file, **hparam_overrides)

# Saved models go under the LightGCN working folder.
hparams.MODEL_DIR = os.path.join(lightgcn_dir, "saved_models")
print(hparams.values())
model = LightGCN(hparams, data, seed=0)

# Time the whole training run; the interval is read after the block exits.
with Timer() as train_time:
    model.fit()

print(f"Took {train_time.interval} seconds for training.")
# Target files for the user and item embeddings of the trained model.
user_emb_file = os.path.join(lightgcn_dir, "user.emb.txt")
item_emb_file = os.path.join(lightgcn_dir, "item.emb.txt")

model.infer_embedding(user_emb_file, item_emb_file)
To compare LightGCN's performance with DKN, we need to make predictions on the same test set. So we infer the user and item embeddings, then compute the similarity score for each user-item pair in the test set.
def infer_scores_via_embeddings(test_filename, user_emb_file, item_emb_file):
    """Score each user-item pair in a test file by an embedding dot product.

    Every non-blank line of ``test_filename`` is split on ``'%'``. The part
    before the first ``'%'`` is split on single spaces: field 0 is the integer
    label and field 2 is the item id. The part after the first ``'%'`` is the
    user id. Blank lines are skipped.

    Args:
        test_filename: Path of the test file to read.
        user_emb_file: Path handed to ``load_emb_file`` for user embeddings.
        item_emb_file: Path handed to ``load_emb_file`` for item embeddings.

    Returns:
        tuple: ``(labels, preds, groupids)``, three parallel lists holding,
        per scored line, the integer label, the dot-product score and the
        user id.

    Raises:
        KeyError: If a user or item id from the test file is missing from
            the loaded embeddings.
    """
    print('loading embedding file...', end=' ')
    # Assumes load_emb_file returns a mapping from id string to a vector
    # supporting .dot() -- the helper is defined outside this notebook.
    user2vec = load_emb_file(user_emb_file)
    item2vec = load_emb_file(item_emb_file)

    preds, labels, groupids = [], [], []
    with open(test_filename, 'r') as rd:
        for line in rd:
            record = line.strip()
            if not record:
                # A blank line carries no '%' section and cannot be parsed.
                continue
            words = record.split('%')
            tokens = words[0].split(' ')
            userid = words[1]
            itemid = tokens[2]
            preds.append(user2vec[userid].dot(item2vec[itemid]))
            labels.append(int(tokens[0]))
            groupids.append(userid)
    print('done')
    return labels, preds, groupids
test_filename = os.path.join(rawdata_dir, 'test_{}.txt'.format(tag))
labels, preds, group_keys = infer_scores_via_embeddings(
    test_filename, user_emb_file, item_emb_file
)

# Bind the grouped labels to a new name: assigning them to `group_labels`
# would overwrite the helper function of that name, so running this cell a
# second time would fail with "'list' object is not callable".
grouped_labels, group_preds = group_labels(labels, preds, group_keys)

# Metrics computed on the grouped labels and predictions.
res_pairwise = cal_metric(
    grouped_labels, group_preds, ['ndcg@2;4;6', "group_auc"]
)
print(res_pairwise)

# AUC computed over all scored pairs, without grouping.
res_pointwise = cal_metric(labels, preds, ['auc'])
print(res_pointwise)