data/datasets/nsfw_selfharm_reddit/prosocial.ipynb
from datasets import load_dataset
import json
import re
import pandas as pd
import numpy as np
import nltk
from nltk import sent_tokenize
from collections import Counter
from sentence_transformers import SentenceTransformer

nltk.download("punkt", quiet=True)  # sent_tokenize below needs the punkt tokenizer

SBERT_MODEL = "all-MiniLM-L6-v2"
nsfw_dataset = load_dataset("jjmachan/NSFW-questions-inter-cleaned_df", split="train")
pro_social_dataset = load_dataset("allenai/prosocial-dialog", split="train")
nsfw_dataset
def match_rot_safetylabels(dataset):
    """Map each rule-of-thumb (ROT) to the safety label of its first occurrence."""
    rot_lists = [item["rots"] for item in dataset]
    safety_annotations = [item["safety_label"] for item in dataset]
    results = {}
    for rots, sfty in zip(rot_lists, safety_annotations):  # avoid shadowing the outer list
        for rot in rots:
            if rot not in results:
                results[rot] = sfty
    return results
rot_sfty = match_rot_safetylabels(pro_social_dataset)
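# Quick sanity check (illustrative): peek at a few ROT -> safety-label pairs.
for _rot, _label in list(rot_sfty.items())[:3]:
    print(_label, "|", _rot)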
all_rots = list(rot_sfty.keys())  # keys are already unique, and this keeps insertion order
def load_vectorizer(model=SBERT_MODEL):
return SentenceTransformer(model)
def vectorize_text(model, texts):
return model.encode(texts, show_progress_bar=True)
model = load_vectorizer()
rot_vector = vectorize_text(model, all_rots)  # embed every ROT once, reused for all batches
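# Sanity check: all-MiniLM-L6-v2 produces 384-dimensional embeddings,
# so this should print (num_rots, 384).
print(rot_vector.shape)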
import scipy.spatial as sp
from collections import defaultdict
from tqdm import tqdm
THRESHOLD = 0.65  # minimum cosine similarity for a post title to "match" a ROT
def match_query_rot(q, m):
    """Return (query_idx, rot_idx) index pairs whose cosine similarity >= THRESHOLD."""
    cosine_sim = 1 - sp.distance.cdist(q, m, "cosine")  # cdist returns distances, so invert
    return np.argwhere(cosine_sim >= THRESHOLD)
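# Illustrative check with toy unit vectors: the identical pair clears the 0.65
# threshold, the orthogonal pair does not.
_q = np.array([[1.0, 0.0], [0.0, 1.0]])
_m = np.array([[1.0, 0.0]])
print(match_query_rot(_q, _m))  # -> [[0 0]]: only query 0 matches ROT 0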
BATCH_SIZE = 100
def match_rot_post(dataset):
    """Match every post title against the ROT embeddings, batching the queries."""
    dic = {}
    posts = [item["title"] for item in dataset]
    post_vector = vectorize_text(model, posts)
    for idx in tqdm(range(0, len(post_vector), BATCH_SIZE)):
        sim_indices = match_query_rot(post_vector[idx : idx + BATCH_SIZE], rot_vector)
        for post_idx, rot_idx in sim_indices:
            rot = all_rots[rot_idx]
            post_id = dataset[int(post_idx) + idx]["post_id"]
            # Accumulate every matching ROT instead of overwriting with the last one.
            entry = dic.setdefault(post_id, {"rots": [], "safety_label": rot_sfty.get(rot)})
            entry["rots"].append(rot)
    return dic
result_dict = match_rot_post(nsfw_dataset)
print("Turaround perc", len(result_dict) / len(nsfw_dataset) * 100)
def filter_stopwords(example):
    """Strip vocative audience words ("Ladies,", "guys", ...) from the post title."""
    stopwords = ["Ladies", "Women", "Gals", "Men", "guys"]
    regex = "|".join([f"{word}(,)?" for word in stopwords])  # no trailing "|", which would match the empty string
    example["title"] = re.sub(regex, "", example["title"], flags=re.IGNORECASE)
    return example
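# Illustrative (hypothetical title): the audience word is stripped, case-insensitively.
print(filter_stopwords({"title": "Ladies, what do you think?"})["title"])  # -> " what do you think?"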
def add_rot_label(example):
    """Attach the matched ROTs and safety label, if this post had a match."""
    match = result_dict.get(example["post_id"])
    if match is not None:
        example["rots"] = match["rots"]
        example["safety_label"] = match["safety_label"]
    return example
def select_response(example):
    """Pick a random comment that is exactly two sentences long and URL-free."""
    comments = comments_df[comments_df["post_id"] == example["post_id"]][
        ["C1", "C2", "C3", "C4", "C5"]
    ].values.tolist()[0]
    comments = [str(comment) for comment in comments]
    comments = [c for c in comments if len(sent_tokenize(c)) == 2]  # keep exactly-two-sentence comments
    comments = [c for c in comments if re.search(r"https?://\S+", c) is None]  # drop comments containing URLs
    if comments:
        example["response"] = np.random.choice(comments)
    return example
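# The two-sentence filter in action on hypothetical comments:
print(len(sent_tokenize("One sentence only.")))                # -> 1, rejected
print(len(sent_tokenize("First sentence. Second sentence.")))  # -> 2, kept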
# Placeholder columns, filled in by the map() calls below.
new_column = [[]] * len(nsfw_dataset)
nsfw_dataset = nsfw_dataset.add_column("rots", new_column)
new_column = [None] * len(nsfw_dataset)
nsfw_dataset = nsfw_dataset.add_column("safety_label", new_column)
new_column = ["None"] * len(nsfw_dataset)
nsfw_dataset = nsfw_dataset.add_column("response", new_column)
nsfw_dataset = nsfw_dataset.map(filter_stopwords)
nsfw_dataset = nsfw_dataset.map(add_rot_label)
# Keep only posts that matched a ROT (safety_label is still None otherwise).
nsfw_dataset = nsfw_dataset.filter(lambda example: example["safety_label"])
post_ids = [item["post_id"] for item in nsfw_dataset]
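# `get_comments` is not defined in this notebook. A minimal sketch of what it is
# assumed to do, given how `select_response` uses comments_df: return a DataFrame
# with one row per post and its top comments in columns C1..C5. Since those
# columns already live on the dataset itself (they are dropped further below),
# one plausible implementation is:
def get_comments(ids):
    df = nsfw_dataset.to_pandas()
    return df[df["post_id"].isin(ids)][["post_id", "C1", "C2", "C3", "C4", "C5"]]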
comments_df = get_comments(post_ids)
nsfw_dataset = nsfw_dataset.map(select_response)
# Reshape to the prosocial-dialog schema: the post title becomes the user turn.
nsfw_dataset = nsfw_dataset.rename_columns({"title": "user"})
nsfw_dataset = nsfw_dataset.remove_columns(["C1", "C2", "C3", "C4", "C5", "link_flair_text", "score", "upvote_ratio"])
nsfw_dataset = nsfw_dataset.shuffle()
nsfw_dataset[35]  # spot-check one shuffled example
# These are single-turn episodes, so every example closes its episode.
new_column = [True] * len(nsfw_dataset)
nsfw_dataset = nsfw_dataset.add_column("episode_done", new_column)
nsfw_dataset.push_to_hub("shahules786/prosocial-nsfw")  # requires `huggingface-cli login` or HF_TOKEN