Back to Open Assistant

writing prompt augmentation data task

notebooks/data-augmentation/writing-prompt/writing_prompt.ipynb

0.0.120.9 KB
Original Source

writing prompt augmentation data task

Pipeline

The goal of this task was to automatically generate question/answer samples from the WritingPrompts dataset for Open Assistant. To do that, we needed to standardize how a prompt is written, so we defined prompt templates that make the generation process feasible. Here are the templates we applied:

  • Base template: every prompt gets a sample in this format.

User: write me a story about: {stripped_prompt} -> Rosey: Sure, here's a story about: {stripped_prompt}:\n{story}

where stripped_prompt is the cleaned prompt extracted by regex patterns that remove the parts of a prompt that would not fit the template, and story is the actual answer to the prompt.

  • General constraints: prompts in which a regex pattern finds a constraint also get this sample.

Base template, {stripped_constraint} -> Rosey: Sure, here's a story about: {stripped_prompt}, {stripped_constraint}:\n{story}

where stripped_constraint is the constraint found.

  • Answer beginning constraints: constrains how the story should start.

Base template, starting with: {beginning} -> Rosey: Sure, here's a story about: {stripped_prompt}, starting with: {beginning}:\n{story}

where beginning is the first sentence of a story.

  • Answer ending constraints: constrains how the story should end.

Base template, ending with: {ending} -> Rosey: Sure, here's a story about {stripped_prompt}: ending with: {ending}\n{story}

where ending is the last sentence of a story.

  • Answer middle constraints: constrains what the middle of the story should be about.

Base template, where the middle of the story is about: {middle} -> Rosey: Sure, here's a story about: {stripped_prompt}, where the middle of the story is about: {middle}:\n{story}

where middle is a summary, produced by a generative model, of a sliding window of sentences taken from the story, excluding its first and last sentences.

To get the samples we used the following pipeline:

  • Get data: download from kaggle
  • Pre-processing: load the source/target files (i.e. prompt/story) of every split (train/valid/test), merge them into one pandas dataframe, and enrich it with tabular info about each sample's tags.
  • Triage prompts: we sort prompts by frequency and build regex patterns for some of them to extract a stripped prompt and the related constraint.
  • Split stories: after removing each story's first and last sentences, we apply a sliding sentence window and summarize each window to get the middle summaries.

Get data from Kaggle

python
# helper functions
import json


def save_credentials(d, path="/root/.kaggle/kaggle.json"):
    """Write Kaggle API credentials to ``path`` as JSON.

    Args:
        d: mapping with the Kaggle ``username`` and ``key`` entries.
        path: destination file. The default is the root user's Kaggle
            config location used by this notebook.

    The parent directory is created when missing (the original version
    failed if ``~/.kaggle`` did not exist yet). The file is then restricted
    to owner read/write (0o600) because it contains an API key.
    """
    import os

    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    with open(path, "w") as outfile:
        json.dump(d, outfile)
    os.chmod(path, 0o600)
python
# uncomment the following instructions, in case you want to save a .kaggle.json
# d = {}
# d['username'] = 'user'
# d['key'] = 'key'
#!mkdir ~/.kaggle
# save_credentials(d)
# Otherwise: move a kaggle.json that was placed in the home directory into ~/.kaggle/.
# NOTE(review): ~/.kaggle must already exist for this mv. If the credentials were
# written with save_credentials(d) instead, ~/kaggle.json presumably does not exist
# and this mv fails; run only one of the two paths.
!mv ~/kaggle.json ~/.kaggle/
# Restrict the credentials file to owner read/write.
!chmod 600 ~/.kaggle/kaggle.json
python
# Install the Kaggle CLI if it is not available yet.
#!pip install kaggle
python
# Download the WritingPrompts dataset archive (writing-prompts.zip).
!kaggle datasets download -d ratthachat/writing-prompts
python
# Extract it; load_data() reads the files under writingPrompts/.
!unzip writing-prompts.zip

Pre-processing

python
import pandas as pd
from IPython.display import display, HTML
python
# helper functions
import re


def load_file(path, names):
    """Load a line-oriented text file into a one-column DataFrame.

    Args:
        path: text file with one record per line.
        names: single-element list with the column name to use.

    Returns:
        DataFrame with one row per line of ``path``.

    The file is decoded explicitly as UTF-8 rather than with the platform's
    locale encoding. Prompts matched later include non-ASCII text, such as
    the "möb..." prompt. The trailing line terminator is removed from every
    value so prompts and stories do not carry a stray "\\n".
    """
    with open(path, "r", encoding="utf-8") as f:
        lines = [line.rstrip("\n") for line in f]
    return pd.DataFrame(lines, columns=names)


def load_data(data_dir="writingPrompts"):
    """Load every WritingPrompts split into a single DataFrame.

    For each split (train/valid/test), ``{data_dir}/{split}.wp_source``
    (prompts) and ``{data_dir}/{split}.wp_target`` (stories) are read and
    paired by index, i.e. by line number.

    Resulting columns, in order:
        - ``prompt``
        - one 0/1 column per tag (lower-cased keys of ``tags``), computed by
          ``check_tag``
        - ``tagCounter``: number of tags found in the prompt
        - ``splitLineIndex``: line number inside the split file
        - ``story``
        - ``split``

    Each split keeps its own 0-based index, so index labels repeat across
    splits in the concatenated result.
    """
    tags = {
        "WP": "Writing Prompt",
        "SP": "Simple Prompt",
        "EU": "Established Universe",
        "CW": "Constrained Writing",
        "TT": "Theme Thursday",
        "PM": "Prompt Me",
        "MP": "Media Prompt",
        "IP": "Image Prompt",
        "PI": "Prompt Inspired",
        "OT": "Off Topic",
        "RF": "Reality Fiction",
    }
    tag_columns = [tag.lower() for tag in tags]

    frames = []
    for split in ["train", "valid", "test"]:
        df = load_file(f"{data_dir}/{split}.wp_source", ["prompt"])
        for column in tag_columns:
            # map() runs immediately, so capturing the loop variable is safe here.
            df[column] = df["prompt"].map(lambda x: check_tag(x, column))
        # Sum all tag flags. The previous `df.iloc[:, [2, -1]]` selected only
        # two columns ("sp" and "rf"), not a range.
        df["tagCounter"] = df[tag_columns].sum(axis=1)
        df["splitLineIndex"] = df.index
        story = load_file(f"{data_dir}/{split}.wp_target", ["story"])
        df["story"] = story["story"]
        df["split"] = split
        frames.append(df)
    # Concatenate once instead of growing a frame that starts out empty.
    return pd.concat(frames)


def check_tag(item, tag):
    """Return 1 if ``item`` has a bracketed tag containing ``tag``, else 0.

    Bracketed tags are two word characters between (), [] or {} with
    optional inner spaces, e.g. ``[ WP ]`` or ``(IP)``. ``item`` is
    lower-cased before matching while ``tag`` is used as given, so pass
    ``tag`` in lower case.
    """
    bracketed = re.findall(r"[\(\{\[]\s*[\w]{2}\s*[\]\}\)]\s*", item.lower())
    return 1 if any(tag in found for found in bracketed) else 0


def show_data(df):
    """Display ``df`` in the notebook as a left-aligned HTML table.

    The dataset encodes line breaks as ``<newline>`` tokens (also written
    ``< newline >`` or ``<new line>``). Those tokens are turned back into
    real newlines, and the cells are styled with ``white-space: pre-wrap``
    so the breaks stay visible.

    Fixes over the previous version:
        - The ``.replace(r"\\n", "<br>")`` literal had been broken across two
          lines, which is a syntax error. It could also never match, because
          the cells contain real newline characters, not backslash-n.
        - A ``Styler`` was built and then discarded.

    Cell text is HTML-escaped by ``DataFrame.to_html``.
    """
    df = df.replace(r"<newline>|< newline >|<new line>", "\n", regex=True)
    table = df.to_html()
    # Inline the alignment/wrapping: the linked df_style.css is not part of
    # this notebook, so it cannot be relied on.
    table = table.replace("<td>", '<td style="text-align:left; white-space:pre-wrap">').replace(
        "<th>", '<th style="text-align:left">'
    )
    page = (
        "<html><head><title>HTML Pandas Dataframe with CSS</title></head>"
        '<link rel="stylesheet" type="text/css" href="df_style.css"/>'
        f"<body>\n{table}\n</body></html>"
    )
    display(HTML(page))


def get_samples(df, n, constraint=None, show=True):
    """Return the first ``n`` prompt/story pairs of a dict-of-DataFrames dataset.

    Args:
        df: dict with ``"prompt"`` and ``"story"`` DataFrames. The first
            column of each holds the prompt text and the story text.
        n: number of leading rows to take, before filtering.
        constraint: optional regex. Only rows whose prompt contains it are kept.
        show: currently unused.

    Returns:
        DataFrame with columns ``index`` (original row label of the prompt),
        ``prompt`` and ``story``.
    """
    first_prompts = df["prompt"].iloc[:n, 0]
    first_stories = df["story"].iloc[:n, 0]
    rows = list(zip(first_prompts.index, first_prompts, first_stories))
    samples = pd.DataFrame(rows, columns=["index", "prompt", "story"])
    if constraint is None:
        return samples
    return samples[samples["prompt"].str.contains(constraint)]
python
# Peek at the first two raw prompt lines of the test split.
!head -n2 writingPrompts/test.wp_source
python
# Load and merge all splits into one DataFrame (see load_data).
ds = load_data()
python
ds.head(3)
python
print(ds.shape)
python
# Column positions 13, 0, 14, -1 are splitLineIndex, prompt, story and split.
ds[ds["split"] == "test"].iloc[:2, [13, 0, 14, -1]].columns

Samples

Train

python
# Render the first two train samples as an HTML table.
show_data(ds[ds["split"] == "train"].iloc[:2][["splitLineIndex", "prompt", "story", "split"]]);

Valid

python
# Render the first two valid samples as an HTML table.
show_data(ds[ds["split"] == "valid"].iloc[:2][["splitLineIndex", "prompt", "story", "split"]]);

Test

python
# Render the first two test samples as an HTML table.
show_data(ds[ds["split"] == "test"].iloc[:2][["splitLineIndex", "prompt", "story", "split"]]);

Augmentation

python
from tqdm import tqdm

Triage Prompts

  1. Take the list of prompts ordered by frequency
  2. Define regex patterns for prompt and constraint
  3. Generate prompts
python
# Count stories per (prompt, split) pair; the count column is named "records".
df_rep = ds.groupby(["prompt", "split"]).size().reset_index().rename(columns={0: "records"})
python
# Keep pairs with more than 20 stories in that split, most frequent first.
df_rep = df_rep[df_rep["records"] > 20].sort_values(["records"], ascending=False)
# _str = df_rep[df_rep['records']>20].sort_values(['records'], ascending=False).iloc[1,0]
python
# A prompt can pass the threshold in more than one split (df_rep is grouped by
# prompt AND split). generate_instruction_diologs then processes all of its
# stories across every split, so a duplicate entry would generate them twice.
# Keep each prompt once, in frequency order.
topPrompts20Reps = list(
    dict.fromkeys(df_rep[df_rep["records"] > 20].sort_values(["records"], ascending=False)["prompt"])
)
python
# Peek at the five most frequent prompts.
topPrompts20Reps[:5]
python
# df_rep[df_rep["split"] == "valid"].iloc[1:3, 0]
# topPrompts20Reps += df_rep[df_rep["split"] == "valid"].iloc[1:3, 0].to_list()
python
# NOTE: "more than 20 stories" is counted within a single split (df_rep is grouped by prompt and split).
print(f"We found {len(topPrompts20Reps)} prompts having more than 20 stories")
python
# Alternation of regexes that extract the story topic ("stripped prompt") from
# the most frequent prompts. Used as extract_prompt_parts(prompt, PROMPT_PATTERNS).
# Most alternatives target one recurring prompt and wrap the text to keep in a
# capture group.
# NOTE(review): this is a regular (non-raw) string. The backslash ending each
# line is a Python line continuation, so the final pattern is a single line
# without newlines. Sequences such as "\s" and "\." survive only because they
# are not valid string escapes (Python 3.12+ emits SyntaxWarning for them).
# NOTE(review): the " (Get me hooked" line starts with a space, which becomes
# part of that alternative. Classes like [\(\) \.\w'-,] contain the range '-,
# (i.e. ' ( ) * + ,). Confirm both are intended.
PROMPT_PATTERNS = "(Lucifer\snever[\s\w,]+)|\
([\. \w,]+)\.\s+Tell me|\
(All injuries[\. \w,]+)\.|\
(?<!\])(At your[\. \w,]+)\.|\
Daily Prompt \: ([\. \w,]+)|\
In 100 words or less , ([\. \w,]+)\.|\
(Last words/thoughts[\. \w,]+)\.|\
(Magic is Hereditary.*) \[|\
word limit (\) [\. \w,\/]+) \.|\
(Make me love the person you love)|\
(Pack a punch) in 150 words|\
(The last man on earth[\. \w,\/]+kill himself)|\
(The year is 2352 [\. \w,\/'-]+)\.|\
(A person dies[\. \w,\/]+)\.?|\
^[wW]rite a story([\. \w,\/]+) |\
^[wW]rite about ([\. \w,\/-]+)\.?|\
^Writing Prompt (?:\: [wW]rite|\
\[ WP \]) ([\. \w,\/']+) ?|\
^(You 're a[\. \w,\/']+)|\
(You 're moments[\. \w,\/']+)\.|\
(Describe the room you [\. \w\/']+)|\
 (Get me hooked \. [ \w,\/']+)|\
[\. \w\/',\`]+ , (tell a horror story)|\
(Make me cry)|\
(Make me hate your character)|\
(Most responses on here have a twist[\. \w\/',\`;]+)|\
(Pick your favorite[\(\)\. \w\/',\`;]+beginning)|\
(Start your story[\(\)\. \w\/',\`;]+meanings \.)|\
(The [\. \w\/',\`;]+ reader)|\
(Two people[\. \w,\/']+bench)|\
Write (a gruesome story)|\
Write (a möb[\. \w,\/']+story) that|\
(Write the letter [ ,\w]+) |\
There is no prompt[ \.\w]+(you[ \.\w']+\.)|\
(A peaceful alien race[ \.\w'-]+)\.|\
(This is the prologue[\(\) \.\w'-]+)\.|\
Write a short story where (the first[\(\) \.\w'-,]+)\.|\
(Write the first and last paragraph[\(\) \.\w'-,]+)\.|\
(Killing Hitler has[\(\) \.\w'-,\?]+)|\
(You live in a city full[\(\) \.\w'-,\?\#]+)|\
\`\` She said she loved him . [\`'\(\) \.\w'-,\?\#]+\.|\
(A soldier on the front dies[\(\) \.\w'-,\?\#]+)|\
(You discover a grand hall[\(\) \.\w'-,\?\#]+)|\
(A boy asks a girl out . It 's high[\(\) \.\w'-,\?\#]+)|\
(When everyone turns 18 , they receive a pet[\(\) \.\w'-,\?\#]+)|\
(To get in Heaven , you have to [\/\(\) \.\w'-,\?\#]+)|\
(You are born without emotions [;\/\(\) \.\w'-,\?\#]+)|\
(You are a teenager with the ability[\`;\/\(\) \.\w'-,\?\#]+)|\
(You live in a world where every person [\`;\/\(\) \.\w'-,\?\#]+)"


# Alternation of regexes that extract a writing constraint (length limit,
# required element, ...) from the most frequent prompts. Used as
# extract_prompt_parts(prompt, CONST_PATTERNS). The constraint text is wrapped
# in a capture group.
# NOTE(review): regular (non-raw) string joined by backslash line continuations,
# like PROMPT_PATTERNS. The " ([\. \w\/',\`]+) , tell a horror story" line starts
# with a space that becomes part of that alternative. Confirm it is intended.
CONST_PATTERNS = "Daily Prompt \: [\. \w,]+\[ ([\. \w,\:]+)|\
(In 100 words or less) , ([\. \w,\:]+) \.|\
Make a story \( ([\. \w,\:]+) |\
Pack a punch (in 150 words)|\
Describe the room you [\. \w\/']+([\. \w,\:\/]+)\.|\
Get me hooked \. Reel me in \. ([\. \w\/',\`]+)\.|\
 ([\. \w\/',\`]+) , tell a horror story|\
Make me cry ([ \w\/',\`]+).?|\
(in 150 words or less)|\
Pick your favorite[\(\)\. \w\/',\`;]+beginning \. ([ \w\/',\`]+)|\
Start your story[\(\)\. \w\/',\`;]+meanings \.([ \w\/',\`]+\.)|\
The [\. \w\/',\`;]+ reader ,([\. \w\/',\`;]+)|\
Two people[\. \w,\/']+bench \. ([\. \w,\:]+)|\
Write a gruesome story ([\. \w,\:]+)|\
Write a möb[\. \w,\/']+story (that[\. \w,\/']+)"

Add summary columns to data

python
#!pip install spacy -qqq

We aim to augment the data as follows:

  • Prompt:
    • whole
      • constraints
  • Story:
    • whole
    • beginning
    • middle - sliding window summarized
    • end

Summarization

python
#!pip install transformers
python
# @markdown utils
from transformers.utils.logging import set_verbosity

# 40 is the numeric value of logging.ERROR: only error-level messages from
# transformers' logger are kept, which silences info/warning output.
set_verbosity(40)

import warnings

# ignore hf pipeline complaints
warnings.filterwarnings("ignore", category=UserWarning, module="transformers")
warnings.filterwarnings("ignore", category=FutureWarning, module="transformers")
python
import torch
from transformers import pipeline

# Long-T5 "book summary" checkpoint, used below to summarize the middle
# sentence windows of each story.
# device=0 selects the first CUDA GPU; -1 runs on the CPU.
summarizer = pipeline(
    task="summarization",
    model="pszemraj/long-t5-tglobal-base-16384-book-summary",
    device=(0 if torch.cuda.is_available() else -1),
)
python
# Keyword arguments forwarded to `summarizer(...)` for every middle window.
# encoder_no_repeat_ngram_size=3 forbids copying 3-grams from the input,
# pushing the model to paraphrase instead of quoting the story.
params = dict(
    max_length=1024,
    min_length=8,
    no_repeat_ngram_size=3,
    early_stopping=False,
    repetition_penalty=3.5,
    length_penalty=0.3,
    encoder_no_repeat_ngram_size=3,
    num_beams=4,
)

Interpolation

python
import spacy
python
# helper functions

import re


def extract_prompt_parts(prompt, pattern):
    """Return the part of ``prompt`` captured by ``pattern``, or None.

    ``pattern`` is an alternation of regexes (e.g. PROMPT_PATTERNS or
    CONST_PATTERNS) where each alternative wraps the interesting text in a
    capture group. Matching is case-insensitive.

    Returns:
        The first capture group that took part in the match, stripped of
        surrounding whitespace. If the matching alternative has no capture
        group, the whole match is used instead. Returns None when nothing
        matches or the extracted text is empty.

    Backslash-newline pairs (line continuations left inside raw-string
    patterns) are removed before searching. The previous code replaced them
    with a lone backslash, which escaped the following character.
    """
    pattern = pattern.replace("\\\n", "")
    m = re.search(pattern, prompt, re.IGNORECASE)
    if m is None:
        return None
    # Previously m.group(0) (the whole match) was returned. That kept prefixes
    # such as "Daily Prompt : " or "Pack a punch " which the capture groups are
    # designed to exclude from the template.
    part = next((g for g in m.groups() if g is not None), m.group(0))
    part = part.strip()
    return part or None


from spacy.lang.en import English


from functools import lru_cache


@lru_cache(maxsize=1)
def _sentencizer():
    """Build once and return a blank English spaCy pipeline with a rule-based sentencizer."""
    nlp = English()
    nlp.add_pipe("sentencizer")
    return nlp


def get_sentences(_str):
    """Split ``_str`` into a list of sentences.

    The text is first split on newlines, then each chunk is segmented with
    spaCy's rule-based sentencizer. Sentences are stripped, and empty ones
    are dropped; blank lines between paragraphs used to produce "" entries.

    The spaCy pipeline is built once and reused. The previous version
    rebuilt it on every call, i.e. once per story.
    """
    nlp = _sentencizer()
    sentences = []
    for chunk in _str.split("\n"):
        sentences.extend(sent.text.strip() for sent in nlp(chunk).sents)
    return [sentence for sentence in sentences if sentence]


from itertools import islice


def window(seq, n=2):
    """Lazily yield every run of ``n`` consecutive items of ``seq``, joined by spaces.

    Nothing is yielded when ``seq`` has fewer than ``n`` items. The input
    may be any iterable; it is consumed one item at a time.
    """
    items = iter(seq)
    current = tuple(islice(items, n))
    if len(current) != n:
        return
    yield " ".join(current)
    for item in items:
        current = (*current[1:], item)
        yield " ".join(current)


def extract_story_parts(story):
    """Split ``story`` into its first sentence, middle windows and last sentence.

    Returns:
        ``(beginning, middles, ending)``:
            - ``beginning``: the first sentence.
            - ``ending``: the last sentence.
            - ``middles``: a lazy generator (from ``window``) of 4-sentence
              windows over the sentences in between.

        When the story has fewer than two sentences it cannot supply both a
        beginning and an ending, so ``(None, None, None)`` is returned. The
        previous version raised IndexError; callers already check
        ``beginning is not None``.
    """
    sentences = get_sentences(story)
    if len(sentences) < 2:
        return None, None, None
    # Slice explicitly instead of popping from the list after the lazy window
    # generator had been created over it.
    middles = window(sentences[1:-1], 4)
    return sentences[0], middles, sentences[-1]


def clear_prompt(prompt):
    """Drop a leading "Write " / "write " from ``prompt``; return it unchanged otherwise."""
    prefix = "Write "
    if prompt.startswith((prefix, "write ")):
        return prompt[len(prefix):]
    return prompt


def get_sample_dict(split, id, text):
    """Package one generated dialog as a record: its split, source line index and text."""
    record = dict(split=split, splitLineIndex=id, text=text)
    return record


def generate_instruction_diologs(df):
    """Build instruction-style dialog samples for every prompt in ``df``.

    NOTE(review): this single-argument version is redefined later in the
    same cell as ``generate_instruction_diologs(prompt, df)``. Once the cell
    has run, this definition is shadowed and unused; it is kept for
    reference.

    Prompts are processed from most to least frequent. Prompts that match no
    PROMPT_PATTERNS alternative are skipped. For every story of a kept
    prompt, dialogs are produced for:
        - the plain template;
        - the extracted constraint, if any;
        - the first sentence and the last sentence, when the story has both;
        - each summarized middle window.

    Fixes over the previous version:
        - Used the missing ``row.splitIndex`` attribute; now ``splitLineIndex``.
        - Consumed the middle-window generator twice.
        - Crashed on stories with fewer than two sentences.
        - Counted rows against a per-prompt progress total.

    Returns:
        list of ``{"split", "splitLineIndex", "text"}`` dicts.
    """
    # Reverse-direction template kept for reference (not generated):
    # User: What is this story about: {story} -> Rosey: I think it's about: {striped_prompt}
    dialogBase = """User: write me a story about: {stripped_prompt}"""
    dialog1 = """ -> Rosey: Sure, here's a story about: {stripped_prompt}:\n{story}"""
    dialog2 = """, {stripped_constraint} -> Rosey: Sure, here's a story about: {stripped_prompt}, {stripped_constraint}:\n{story}"""
    dialog3 = """, starting with: {beginning} -> Rosey: Sure, here's a story about: {stripped_prompt}, starting with: {beginning}:\n{story}"""
    # NOTE(review): dialog4's answer header is punctuated differently from the
    # other templates; kept unchanged to match the documented template.
    dialog4 = """, ending with: {ending} -> Rosey: Sure, here's a story about {stripped_prompt}: ending with: {ending}\n{story}"""
    dialog5 = """, where the middle of the story is about: {middle} -> Rosey: Sure, here's a story about: {stripped_prompt}, where the middle of the story is about: {middle}:\n{story}"""

    dialogs = []
    counts = df.groupby("prompt").size().sort_values(ascending=False)
    # One progress step per prompt.
    for prompt in tqdm(counts.index, total=len(counts)):
        strippedPrompt = extract_prompt_parts(prompt, PROMPT_PATTERNS)
        if strippedPrompt is None:
            continue
        strippedPrompt = clear_prompt(strippedPrompt)
        strippedConstraint = extract_prompt_parts(prompt, CONST_PATTERNS)
        dialogBeg = dialogBase.format(stripped_prompt=strippedPrompt)

        for row in df[df["prompt"] == prompt].itertuples():
            try:
                story = (
                    row.story.replace("<newline>", "\n")
                    .replace("< newline >", "\n")
                    .replace("<new line>", "\n")
                    .strip()
                )
                answers = [dialog1.format(story=story, stripped_prompt=strippedPrompt)]
                if strippedConstraint is not None:
                    answers.append(
                        dialog2.format(
                            stripped_prompt=strippedPrompt, stripped_constraint=strippedConstraint, story=story
                        )
                    )
                try:
                    beginning, middles, ending = extract_story_parts(story)
                except IndexError:
                    # Too few sentences to split off a first and a last one.
                    beginning = None
                if beginning is not None:
                    answers.append(dialog3.format(stripped_prompt=strippedPrompt, story=story, beginning=beginning))
                    answers.append(dialog4.format(stripped_prompt=strippedPrompt, story=story, ending=ending))
                    # Materialize the windows so the summarizer sees each one
                    # exactly once (the old code also zipped over the same generator).
                    middles = list(middles)
                    if middles:
                        for summary in summarizer(middles, **params):
                            # Accept both result shapes: {...} or [{...}].
                            if isinstance(summary, list):
                                summary = summary[0]
                            answers.append(
                                dialog5.format(
                                    stripped_prompt=strippedPrompt, story=story, middle=summary["summary_text"]
                                )
                            )
                dialogs.extend(
                    get_sample_dict(row.split, row.splitLineIndex, dialogBeg + answer) for answer in answers
                )
            except Exception:
                print(f"{row.split}/{row.splitLineIndex}")
                raise
    return dialogs


def filter_data(
    dataset,
    negativeTagFilter=None,
    positiveTagFilter=None,
    patternFilter=None,
):
    """
    > filter_data(dataset['train'],negativeTagFilter=['ip'], positiveTagFilter=['pm'] )

    Filter a ``{"prompt": DataFrame, "story": DataFrame}`` dataset.

    Args:
        dataset: dict with two DataFrames.
            - ``"prompt"``: a ``prompt`` text column plus 0/1 tag columns.
            - ``"story"``: presumably row-aligned with the prompt frame by
              index label.
        negativeTagFilter: tag columns that must ALL be absent (< 1).
        positiveTagFilter: tag columns that must ALL be present (> 0).
        patternFilter: regex that the prompt text must contain
            (``Series.str.contains``).

    Returns:
        dict with the filtered ``"prompt"`` frame and the ``"story"`` rows
        sharing its index labels.
    """
    prompt = dataset["prompt"]
    if negativeTagFilter is not None:
        # Drop a prompt if it carries ANY unwanted tag. The previous
        # `.any(axis=1)` kept prompts that lacked just one of them.
        prompt = prompt[prompt[negativeTagFilter].lt(1).all(axis=1)]
    if positiveTagFilter is not None:
        prompt = prompt[prompt[positiveTagFilter].gt(0).all(axis=1)]
    if patternFilter is not None:
        prompt = prompt[prompt["prompt"].str.contains(patternFilter)]
    # Select stories by index label. The previous positional `.iloc` with
    # labels was only correct for a default 0..n-1 index.
    story = dataset["story"].loc[prompt.index]
    return {"prompt": prompt, "story": story}


def generate_instruction_diologs(prompt, df):
    """Build instruction-style dialog samples for all stories written for ``prompt``.

    The prompt is reduced to a "stripped prompt" with PROMPT_PATTERNS, and an
    optional constraint is extracted with CONST_PATTERNS. If no prompt
    pattern matches, no samples are produced.

    For every row of ``df`` whose ``prompt`` equals ``prompt``, the story's
    ``<newline>`` tokens are turned into line breaks and these samples are
    emitted, in order:
        1. the plain template;
        2. the constraint template, when a constraint was extracted;
        3. the "starting with" template and 4. the "ending with" template,
           when the story has both a first and a last sentence;
        5. one "middle" template per summarized middle window.

    Args:
        prompt: exact prompt text as stored in ``df["prompt"]``.
        df: DataFrame with ``prompt``, ``story``, ``split`` and
            ``splitLineIndex`` columns (as built by ``load_data``).

    Returns:
        list of ``{"split", "splitLineIndex", "text"}`` dicts. It is empty
        when the prompt matches no pattern.

    Raises:
        Re-raises any error hit while processing a story, after printing the
        story's ``split/splitLineIndex`` to help locate it.
    """
    # Reverse-direction template kept for reference (not generated):
    # User: What is this story about: {story} -> Rosey: I think it's about: {striped_prompt}
    dialogBase = """User: write me a story about: {stripped_prompt}"""
    dialog1 = """ -> Rosey: Sure, here's a story about: {stripped_prompt}:\n{story}"""
    dialog2 = """, {stripped_constraint} -> Rosey: Sure, here's a story about: {stripped_prompt}, {stripped_constraint}:\n{story}"""
    dialog3 = """, starting with: {beginning} -> Rosey: Sure, here's a story about: {stripped_prompt}, starting with: {beginning}:\n{story}"""
    # NOTE(review): dialog4's answer header ("about {stripped_prompt}: ending
    # with: {ending}\n") is punctuated differently from the other templates;
    # kept unchanged so it matches the documented template.
    dialog4 = """, ending with: {ending} -> Rosey: Sure, here's a story about {stripped_prompt}: ending with: {ending}\n{story}"""
    dialog5 = """, where the middle of the story is about: {middle} -> Rosey: Sure, here's a story about: {stripped_prompt}, where the middle of the story is about: {middle}:\n{story}"""

    dialogs = []
    strippedPrompt = extract_prompt_parts(prompt, PROMPT_PATTERNS)
    if strippedPrompt is None:
        return dialogs
    strippedPrompt = clear_prompt(strippedPrompt)
    strippedConstraint = extract_prompt_parts(prompt, CONST_PATTERNS)
    dialogBeg = dialogBase.format(stripped_prompt=strippedPrompt)

    # Filter once; the previous version evaluated this mask twice.
    rows = df[df["prompt"] == prompt]
    for row in tqdm(rows.itertuples(), total=len(rows), ascii=True, desc="stories"):
        try:
            story = (
                row.story.replace("<newline>", "\n")
                .replace("< newline >", "\n")
                .replace("<new line>", "\n")
                .strip()
            )
            answers = [dialog1.format(story=story, stripped_prompt=strippedPrompt)]
            if strippedConstraint is not None:
                answers.append(
                    dialog2.format(
                        stripped_prompt=strippedPrompt, stripped_constraint=strippedConstraint, story=story
                    )
                )
            try:
                beginning, middles, ending = extract_story_parts(story)
            except IndexError:
                # Too few sentences to split off a first and a last one.
                beginning = None
            if beginning is not None:
                answers.append(dialog3.format(stripped_prompt=strippedPrompt, story=story, beginning=beginning))
                answers.append(dialog4.format(stripped_prompt=strippedPrompt, story=story, ending=ending))
                # Materialize the windows first. The previous code passed the
                # window generator to the summarizer and then zipped over the
                # same generator, so the two competed for its items and windows
                # were skipped.
                middles = list(middles)
                if middles:
                    for summary in summarizer(middles, **params):
                        # Depending on the input type, the pipeline may return
                        # each result as {...} or as [{...}]; accept both.
                        if isinstance(summary, list):
                            summary = summary[0]
                        answers.append(
                            dialog5.format(
                                stripped_prompt=strippedPrompt, story=story, middle=summary["summary_text"]
                            )
                        )
            dialogs.extend(
                get_sample_dict(row.split, row.splitLineIndex, dialogBeg + answer) for answer in answers
            )
        except Exception:
            print(f"{row.split}/{row.splitLineIndex}")
            raise
    return dialogs

Generate

The generation loop saves the accumulated samples to a parquet file after every `step` prompts to avoid losing work.

python
## Generate dialogs for topPrompts20Reps (prompts with more than 20 stories in a split).
# The accumulated dialogs are checkpointed to parquet after every `step` prompts.
# To resume after an interruption, set `start` to the index of the first prompt
# that was not saved yet. The existing checkpoint is then loaded and extended
# instead of being overwritten with only the new dialogs.
import os

dialogs = []
start = 0
step = 10
if start > 0 and os.path.exists("writing-prompts-aug.parquet"):
    dialogs = pd.read_parquet("writing-prompts-aug.parquet").to_dict("records")
for index in range(start, len(topPrompts20Reps), step):
    for prompt in tqdm(topPrompts20Reps[index : index + step], ascii=True, desc="prompt"):
        dialogs += generate_instruction_diologs(prompt, ds)
    if dialogs:
        pd.DataFrame(dialogs).to_parquet("writing-prompts-aug.parquet")
python
# Reload the dialogs written by the generation loop (columns: split, splitLineIndex, text).
df = pd.read_parquet("writing-prompts-aug.parquet")
python
# Write one parquet file per split ("{split}.parquet"), keeping only the sample
# columns. Columns are selected by name rather than by position, so the output
# no longer depends on "split" being the first column. Splits are visited in
# sorted order for a deterministic run.
for split in sorted(set(df["split"])):
    df_aux = df.loc[df["split"] == split, ["splitLineIndex", "text"]].reset_index(drop=True)
    df_aux.to_parquet(f"{split}.parquet")