Back to Open Assistant

Prosocial

data/datasets/nsfw_selfharm_reddit/prosocial.ipynb

0.0.14.4 KB
Original Source
python
from datasets import load_dataset
import json
import pandas as pd
import numpy as np
from sentence_transformers import SentenceTransformer

SBERT_MODEL = "all-MiniLM-L6-v2"
from collections import Counter
import nltk
import re
from nltk import sent_tokenize
python
# Source posts (NSFW reddit questions) and the AllenAI ProSocial dialog
# corpus, which supplies rules-of-thumb (ROTs) and safety labels.
nsfw_dataset = load_dataset("jjmachan/NSFW-questions-inter-cleaned_df", split="train")
pro_social_dataset = load_dataset("allenai/prosocial-dialog", split="train")
python
nsfw_dataset
python
def match_rot_safetylabels(dataset):
    """Map each rule-of-thumb (ROT) string to a safety label.

    Parameters
    ----------
    dataset : iterable of dict
        Items with a "rots" list of strings and a "safety_label" string.

    Returns
    -------
    dict
        rot -> safety label of the first item that mentioned that rot
        (first occurrence wins, matching the original behavior).
    """
    results = {}
    for item in dataset:
        label = item["safety_label"]
        for rot in item["rots"]:
            # setdefault keeps the label of the first item citing this rot;
            # also avoids the original's loop-variable shadowing of `rots`.
            results.setdefault(rot, label)
    return results
python
# rot -> safety-label mapping, plus the list of unique ROT strings.
rot_sfty = match_rot_safetylabels(pro_social_dataset)
# Dict keys are already unique, so set() was redundant and only made the
# ordering (and thus rot_vector's row order) nondeterministic across runs.
all_rots = list(rot_sfty)
python
def load_vectorizer(model=SBERT_MODEL):
    """Instantiate and return the sentence-embedding model named *model*."""
    vectorizer = SentenceTransformer(model)
    return vectorizer


def vectorize_text(model, texts):
    """Encode *texts* into embedding vectors with *model*, showing a progress bar."""
    embeddings = model.encode(texts, show_progress_bar=True)
    return embeddings
python
model = load_vectorizer()
python
rot_vector = vectorize_text(model, all_rots)
python
import scipy.spatial as sp
from collections import defaultdict
from tqdm import tqdm

# Minimum cosine similarity for a query/ROT pair to count as a match.
THRESHOLD = 0.65


def match_query_rot(q, m):
    """Return (row, col) index pairs where rows of *q* match rows of *m*.

    A pair qualifies when its cosine similarity is at least THRESHOLD.
    """
    similarity = 1 - sp.distance.cdist(q, m, "cosine")
    return np.argwhere(similarity >= THRESHOLD)
python
# Titles are embedded once, then compared to the ROT matrix in slices of
# this size to bound the pairwise-distance memory footprint.
BATCH_SIZE = 100


def match_rot_post(dataset):
    """Match every post title in *dataset* against the ROT embeddings.

    Returns a dict keyed by post_id with:
      - "rots": all matched ROT strings (the original kept only the last
        match per post, silently dropping earlier ones),
      - "safety_label": the label of the first matched ROT.
    """
    matches = {}
    titles = [item["title"] for item in dataset]
    title_vectors = vectorize_text(model, titles)
    for start in tqdm(range(0, len(title_vectors), BATCH_SIZE)):
        sim_indices = match_query_rot(title_vectors[start : start + BATCH_SIZE], rot_vector)
        for post_idx, rot_idx in sim_indices:
            rot = all_rots[rot_idx]
            # post_idx is relative to the batch; offset back into the dataset
            post_id = dataset[int(post_idx) + start]["post_id"]
            entry = matches.setdefault(post_id, {"rots": [], "safety_label": rot_sfty.get(rot)})
            if rot not in entry["rots"]:
                entry["rots"].append(rot)
    return matches
python
result_dict = match_rot_post(nsfw_dataset)
python
print("Turaround perc", len(result_dict) / len(nsfw_dataset) * 100)
python
def filter_stopwords(example):
    """Strip gendered address words (e.g. "Ladies, ...") from the post title.

    Mutates and returns *example* (a dict with a "title" string).
    """
    stopwords = ["Ladies", "Women", "Gals", "Men", "guys"]
    # \b guards stop IGNORECASE from eating substrings (the original pattern
    # turned "comment" into "comt" by matching "men", and its trailing "|"
    # made the regex match the empty string at every position).
    # The optional trailing comma mirrors "Ladies, ..." style addresses.
    regex = r"\b(?:" + "|".join(stopwords) + r")\b,?"
    example["title"] = re.sub(regex, "", example["title"], flags=re.IGNORECASE)
    return example
python
def add_rot_label(example):
    """Copy the matched ROTs and safety label onto *example*, if any.

    Looks up example["post_id"] in the module-level ``result_dict``; examples
    with no match are returned unchanged (their safety_label stays None, so
    the later filter() drops them). Single lookup replaces the original's
    three separate dict accesses.
    """
    match = result_dict.get(example["post_id"])
    if match is not None:
        example["rots"] = match["rots"]
        example["safety_label"] = match["safety_label"]
    return example
python
def select_response(example):
    """Pick a random qualifying reddit comment as the response for a post.

    Looks up the post's C1..C5 comments in the module-level ``comments_df``,
    keeps only comments of exactly two sentences that contain no URL, and
    samples one uniformly. Leaves example["response"] untouched when no
    comment qualifies.

    NOTE(review): assumes every post_id has a row in comments_df — the
    ``[0]`` indexing raises IndexError otherwise; confirm against caller.
    """
    row = comments_df[comments_df["post_id"] == example["post_id"]][
        ["C1", "C2", "C3", "C4", "C5"]
    ].values.tolist()[0]
    candidates = []
    for comment in row:
        comment = str(comment)
        # tokenize once per comment (original called sent_tokenize twice);
        # "> 1 and < 3" on an int count means exactly two sentences
        if len(sent_tokenize(comment)) != 2:
            continue
        # raw string fixes the invalid "\s" escape in the original pattern
        if re.search(r"(?P<url>https?://[^\s]+)", comment) is None:
            candidates.append(comment)

    if candidates:
        example["response"] = np.random.choice(candidates, 1)[0]
        print(example["response"])  # progress output, kept from the original

    return example
python
# Pre-create the columns that the map() passes below will populate:
# "rots" (empty list), "safety_label" (None), "response" ("None" string).
new_column = [[]] * len(nsfw_dataset)  # NOTE(review): [[]] * n repeats one shared list object; assumes add_column copies the values — confirm
nsfw_dataset = nsfw_dataset.add_column("rots", new_column)
new_column = [None] * len(nsfw_dataset)
nsfw_dataset = nsfw_dataset.add_column("safety_label", new_column)
new_column = ["None"] * len(nsfw_dataset)  # string sentinel, not Python None
nsfw_dataset = nsfw_dataset.add_column("response", new_column)
python
# Clean titles, attach ROT/safety labels, then drop posts with no match
# (safety_label is still None when the post_id was absent from result_dict).
nsfw_dataset = nsfw_dataset.map(filter_stopwords)
nsfw_dataset = nsfw_dataset.map(add_rot_label)
nsfw_dataset = nsfw_dataset.filter(lambda example: example["safety_label"])

Comments

python
# Fetch the top comments for the surviving posts.
# NOTE(review): get_comments is not defined in this notebook — presumably a
# helper defined or imported elsewhere; verify before running end-to-end.
post_ids = [item["post_id"] for item in nsfw_dataset]
comments_df = get_comments(post_ids)
python
nsfw_dataset = nsfw_dataset.map(select_response)
python
# Rename the title column to "user" and drop the raw reddit metadata columns.
nsfw_dataset = nsfw_dataset.rename_columns({"title": "user"})
nsfw_dataset = nsfw_dataset.remove_columns(["C1", "C2", "C3", "C4", "C5", "link_flair_text", "score", "upvote_ratio"])
python
# Shuffle rows and spot-check a single example in the notebook output.
nsfw_dataset = nsfw_dataset.shuffle()
nsfw_dataset[35]
python
# Every row is a single-turn exchange, so mark each episode as finished.
new_column = [True] * len(nsfw_dataset)
nsfw_dataset = nsfw_dataset.add_column("episode_done", new_column)
python
nsfw_dataset.push_to_hub("shahules786/prosocial-nsfw")