notebooks/data-augmentation/writing-prompt/writing_prompt.ipynb
The goal of this task was to auto-generate question/answer samples from WritingPrompts to feed Open Assistant. To do that we needed to standardize the way a prompt is written. Our choice was to define prompt templates that make the generation process feasible. Here are the templates we applied:
User: write me a story about: {stripped_prompt} -> Rosey: Sure, here's a story about: {stripped_prompt}:\n{story}
where stripped_prompt is the cleaned prompt produced by a regex pattern that strips out the parts of a prompt that would not fit the template, and story is the actual answer to a prompt.
Base template, {stripped_constraint} -> Rosey: Sure, here's a story about: {stripped_prompt}, {stripped_constraint}:\n{story}
where stripped_constraint is the constraint found.
Base template, starting with: {beginning} -> Rosey: Sure, here's a story about: {stripped_prompt}, starting with: {beginning}:\n{story}
where beginning is the first sentence of a story.
Base template, ending with: {ending} -> Rosey: Sure, here's a story about {stripped_prompt}: ending with: {ending}\n{story}
where ending is the last sentence of a story.
Base template, where the middle of the story is about: {middle} -> Rosey: Sure, here's a story about: {stripped_prompt}, where the middle of the story is about: {middle}:\n{story}
where middle is a summary, produced by a generative model, of the story without its first and last sentences.
To get the samples we used the following pipeline:
# helper functions
import json
def save_credentials(d, path="/root/.kaggle/kaggle.json"):
    """Write Kaggle API credentials to disk as JSON.

    Parameters:
        d: dict with the Kaggle credentials (keys "username" and "key").
        path: destination file; defaults to the hard-coded location the
            original used, so existing callers are unaffected.

    The parent directory is created if missing — the original required a
    manual `!mkdir ~/.kaggle` beforehand (see the commented cell below).
    """
    import os

    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, "w") as outfile:
        json.dump(d, outfile)
# uncomment the following instructions, in case you want to save a .kaggle.json
# d = {}
# d['username'] = 'user'
# d['key'] = 'key'
#!mkdir ~/.kaggle
# save_credentials(d)
# Move the uploaded Kaggle API token into place and lock down its permissions
# (the Kaggle CLI refuses world-readable tokens), then download and extract
# the WritingPrompts dataset.
!mv ~/kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json
#!pip install kaggle
!kaggle datasets download -d ratthachat/writing-prompts
!unzip writing-prompts.zip
import pandas as pd
from IPython.display import display, HTML
# helper functions
import re
def load_file(path, names):
    """Read a text file and return its raw lines (newlines included) as a
    single-column DataFrame labelled with `names`."""
    with open(path, "r") as handle:
        raw_lines = handle.readlines()
    return pd.DataFrame(raw_lines, columns=names)
def load_data():
    """Load the three WritingPrompts splits into one DataFrame.

    Adds: one 0/1 column per two-letter tag, a tag counter, the line index
    within the split file (`splitLineIndex`), the story text and the split
    name.
    """
    # Two-letter tag codes that appear bracketed in prompt titles, e.g. "[ WP ]".
    tags = {
        "WP": "Writing Prompt",
        "SP": "Simple Prompt",
        "EU": "Established Universe",
        "CW": "Constrained Writing",
        "TT": "Theme Thursday",
        "PM": "Prompt Me",
        "MP": "Media Prompt",
        "IP": "Image Prompt",
        "PI": "Prompt Inspired",
        "OT": "Off Topic",
        "RF": "Reality Fiction",
    }
    dfConcat = pd.DataFrame()
    for split in ["train", "valid", "test"]:
        df = load_file(f"writingPrompts/{split}.wp_source", ["prompt"])
        for tag in tags.keys():
            df[tag.lower()] = df["prompt"].map(lambda x: check_tag(x, tag.lower()))
        # NOTE(review): iloc[:, [2, -1]] sums only the 2nd and the last tag
        # columns, not all eleven — looks like a slice over every tag column
        # was intended; confirm before relying on tagCounter.
        df["tagCounter"] = df.iloc[:, [2, -1]].sum(axis=1)
        # Line number inside the split file; .wp_source and .wp_target are
        # aligned line-by-line, which is why the story column can be copied
        # over positionally below.
        df["splitLineIndex"] = df.index
        story = load_file(f"writingPrompts/{split}.wp_target", ["story"])
        df["story"] = story["story"]
        df["split"] = split
        dfConcat = pd.concat([dfConcat, df])
    return dfConcat
# Matches a bracketed two-letter tag marker such as "[ wp ]", "(cw)" or "{tt}".
# Hoisted to module level: the original recompiled this pattern on every call
# (once per prompt per tag).
_TAG_MARKER_RE = re.compile(r"[\(\{\[]\s*[\w]{2}\s*[\]\}\)]\s*")

def check_tag(item, tag):
    """Return 1 if the two-letter `tag` occurs inside a bracketed tag marker
    of `item` (case-insensitive), else 0.

    Parameters:
        item: raw prompt text.
        tag: lowercase two-letter tag code, e.g. "wp".
    """
    return int(any(tag in marker for marker in _TAG_MARKER_RE.findall(item.lower())))
def show_data(df):
    """Display a DataFrame as left-aligned HTML with real line breaks.

    Replaces the dataset's "<newline>" placeholder tokens with newline
    characters and injects inline alignment styles so prompt/story text
    reads naturally in the notebook.
    """
    html_template = (
        "<html>"
        "<head><title>HTML Pandas Dataframe with CSS</title></head>"
        '<link rel="stylesheet" type="text/css" href="df_style.css"/>'
        "<body>\n{table}\n</body>\n</html>\n"
    )
    # Turn the dataset's newline placeholders into actual newlines.
    df = df.replace("\<newline\>|\< newline \>|\<new line\>", "\n", regex=True)
    # BUG FIX: the original built a pandas Styler here and discarded its
    # result (a no-op); alignment is applied via the inline styles injected
    # below instead.
    html_string = html_template.format(table=df.to_html())
    html_string = (
        # BUG FIX: the replacement string in the source was garbled into a
        # broken literal; the intent — turning literal "\n" sequences into
        # real newlines — is restored here.
        html_string.replace(r"\n", "\n")
        .replace("<td>", '<td style="text-align:left">')
        .replace("<th>", '<th style="text-align:left">')
    )
    display(HTML(html_string))
def get_samples(df, n, constraint=None, show=True):
    """Return the first `n` rows of `df` as a fresh (index, prompt, story)
    DataFrame, optionally filtered to prompts containing `constraint`.

    BUG FIX: the original indexed a Series with two axes
    (`df["prompt"].iloc[:n, 0]`), which raises "Too many indexers".

    Parameters:
        df: DataFrame with "prompt" and "story" columns.
        n: number of leading rows to take.
        constraint: optional substring/regex the prompt must contain.
        show: kept for interface compatibility; unused here (as in the
            original).
    """
    head = df.iloc[:n]
    out = pd.DataFrame(
        {
            "index": head.index,
            "prompt": head["prompt"].to_numpy(),
            "story": head["story"].to_numpy(),
        }
    )
    if constraint is not None:
        out = out[out["prompt"].str.contains(constraint)]
    return out
# Peek at the raw source file, then load all splits into one DataFrame.
!head -n2 writingPrompts/test.wp_source
ds = load_data()
ds.head(3)
print(ds.shape)
# Quick positional-column sanity check on a couple of test rows.
ds[ds["split"] == "test"].iloc[:2, [13, 0, 14, -1]].columns
show_data(ds[ds["split"] == "train"].iloc[:2][["splitLineIndex", "prompt", "story", "split"]]);
show_data(ds[ds["split"] == "valid"].iloc[:2][["splitLineIndex", "prompt", "story", "split"]]);
show_data(ds[ds["split"] == "test"].iloc[:2][["splitLineIndex", "prompt", "story", "split"]]);
from tqdm import tqdm
# Count how many stories share the exact same prompt text, per split.
df_rep = ds.groupby(["prompt", "split"]).size().reset_index().rename(columns={0: "records"})
df_rep = df_rep[df_rep["records"] > 20].sort_values(["records"], ascending=False)
# _str = df_rep[df_rep['records']>20].sort_values(['records'], ascending=False).iloc[1,0]
# Only prompts with more than 20 stories are worth hand-writing a template for.
topPrompts20Reps = df_rep[df_rep["records"] > 20].sort_values(["records"], ascending=False)["prompt"].tolist()
topPrompts20Reps[:5]
# df_rep[df_rep["split"] == "valid"].iloc[1:3, 0]
# topPrompts20Reps += df_rep[df_rep["split"] == "valid"].iloc[1:3, 0].to_list()
print(f"We found {len(topPrompts20Reps)} prompts having more than 20 stories")
# Alternation of hand-crafted regexes, one branch per known prompt phrasing;
# each branch captures the core "stripped" prompt used in the templates.
# The trailing backslash before each newline continues the string literal,
# so the whole constant is ONE long pattern.  The dataset is whitespace-
# tokenized (e.g. "You 're", " ."), which explains the unusual spacing.
PROMPT_PATTERNS = "(Lucifer\snever[\s\w,]+)|\
([\. \w,]+)\.\s+Tell me|\
(All injuries[\. \w,]+)\.|\
(?<!\])(At your[\. \w,]+)\.|\
Daily Prompt \: ([\. \w,]+)|\
In 100 words or less , ([\. \w,]+)\.|\
(Last words/thoughts[\. \w,]+)\.|\
(Magic is Hereditary.*) \[|\
word limit (\) [\. \w,\/]+) \.|\
(Make me love the person you love)|\
(Pack a punch) in 150 words|\
(The last man on earth[\. \w,\/]+kill himself)|\
(The year is 2352 [\. \w,\/'-]+)\.|\
(A person dies[\. \w,\/]+)\.?|\
^[wW]rite a story([\. \w,\/]+) |\
^[wW]rite about ([\. \w,\/-]+)\.?|\
^Writing Prompt (?:\: [wW]rite|\
\[ WP \]) ([\. \w,\/']+) ?|\
^(You 're a[\. \w,\/']+)|\
(You 're moments[\. \w,\/']+)\.|\
(Describe the room you [\. \w\/']+)|\
(Get me hooked \. [ \w,\/']+)|\
[\. \w\/',\`]+ , (tell a horror story)|\
(Make me cry)|\
(Make me hate your character)|\
(Most responses on here have a twist[\. \w\/',\`;]+)|\
(Pick your favorite[\(\)\. \w\/',\`;]+beginning)|\
(Start your story[\(\)\. \w\/',\`;]+meanings \.)|\
(The [\. \w\/',\`;]+ reader)|\
(Two people[\. \w,\/']+bench)|\
Write (a gruesome story)|\
Write (a möb[\. \w,\/']+story) that|\
(Write the letter [ ,\w]+) |\
There is no prompt[ \.\w]+(you[ \.\w']+\.)|\
(A peaceful alien race[ \.\w'-]+)\.|\
(This is the prologue[\(\) \.\w'-]+)\.|\
Write a short story where (the first[\(\) \.\w'-,]+)\.|\
(Write the first and last paragraph[\(\) \.\w'-,]+)\.|\
(Killing Hitler has[\(\) \.\w'-,\?]+)|\
(You live in a city full[\(\) \.\w'-,\?\#]+)|\
\`\` She said she loved him . [\`'\(\) \.\w'-,\?\#]+\.|\
(A soldier on the front dies[\(\) \.\w'-,\?\#]+)|\
(You discover a grand hall[\(\) \.\w'-,\?\#]+)|\
(A boy asks a girl out . It 's high[\(\) \.\w'-,\?\#]+)|\
(When everyone turns 18 , they receive a pet[\(\) \.\w'-,\?\#]+)|\
(To get in Heaven , you have to [\/\(\) \.\w'-,\?\#]+)|\
(You are born without emotions [;\/\(\) \.\w'-,\?\#]+)|\
(You are a teenager with the ability[\`;\/\(\) \.\w'-,\?\#]+)|\
(You live in a world where every person [\`;\/\(\) \.\w'-,\?\#]+)"
# Companion patterns to PROMPT_PATTERNS: each branch captures the CONSTRAINT
# part of a prompt (word limits, stylistic demands, ...) used by the
# ", {stripped_constraint}" template variant.  Same single-string
# continuation trick as above.
CONST_PATTERNS = "Daily Prompt \: [\. \w,]+\[ ([\. \w,\:]+)|\
(In 100 words or less) , ([\. \w,\:]+) \.|\
Make a story \( ([\. \w,\:]+) |\
Pack a punch (in 150 words)|\
Describe the room you [\. \w\/']+([\. \w,\:\/]+)\.|\
Get me hooked \. Reel me in \. ([\. \w\/',\`]+)\.|\
([\. \w\/',\`]+) , tell a horror story|\
Make me cry ([ \w\/',\`]+).?|\
(in 150 words or less)|\
Pick your favorite[\(\)\. \w\/',\`;]+beginning \. ([ \w\/',\`]+)|\
Start your story[\(\)\. \w\/',\`;]+meanings \.([ \w\/',\`]+\.)|\
The [\. \w\/',\`;]+ reader ,([\. \w\/',\`;]+)|\
Two people[\. \w,\/']+bench \. ([\. \w,\:]+)|\
Write a gruesome story ([\. \w,\:]+)|\
Write a möb[\. \w,\/']+story (that[\. \w,\/']+)"
#!pip install spacy -qqq
We aim to augment the data as follows:
#!pip install transformers
# @markdown utils
from transformers.utils.logging import set_verbosity
set_verbosity(40)
import warnings
# ignore hf pipeline complaints
warnings.filterwarnings("ignore", category=UserWarning, module="transformers")
warnings.filterwarnings("ignore", category=FutureWarning, module="transformers")
import torch
from transformers import pipeline
# Book-length summarization model used to condense story middles for the
# "where the middle of the story is about" template.
summarizer = pipeline(
    "summarization",
    "pszemraj/long-t5-tglobal-base-16384-book-summary",
    device=0 if torch.cuda.is_available() else -1,  # first GPU, else CPU
)
# Generation settings passed to the summarization pipeline on every call.
params = {
    "max_length": 1024,
    "min_length": 8,
    "no_repeat_ngram_size": 3,
    "early_stopping": False,
    "repetition_penalty": 3.5,
    "length_penalty": 0.3,
    "encoder_no_repeat_ngram_size": 3,
    "num_beams": 4,
}  # parameters for text generation out of model
import spacy
# helper functions
import re
def extract_prompt_parts(prompt, pattern):
    """Return the text of `prompt` matched by `pattern`, or None.

    The search is case-insensitive.  A result is only returned when the
    pattern defines at least one capture group; the full match (group 0)
    is what comes back.
    """
    # Collapse any backslash-newline sequences left in the multi-line
    # pattern constants before searching.
    normalized = pattern.replace("\\\n", "\\")
    match = re.search(normalized, prompt, re.IGNORECASE)
    if match is None or not match.groups():
        return None
    return match.group(0)
from spacy.lang.en import English

# Cached sentencizer pipeline.  PERF FIX: the original constructed a fresh
# spaCy English() pipeline on every call — once per story chunk batch.
_sentencizer_nlp = None

def get_sentences(_str):
    """Split text into a flat list of stripped sentences.

    The text is first split on newlines (paragraph boundaries restored from
    the dataset's placeholder tokens), then each chunk is run through a
    cached spaCy sentencizer.
    """
    global _sentencizer_nlp
    if _sentencizer_nlp is None:
        _sentencizer_nlp = English()
        _sentencizer_nlp.add_pipe("sentencizer")
    sentences = []
    for chunk in _str.split("\n"):
        doc = _sentencizer_nlp(chunk)
        sentences += [sent.text.strip() for sent in doc.sents]
    return sentences
from itertools import islice

def window(seq, n=2):
    """Yield each length-`n` sliding window of `seq`, space-joined.

    Yields nothing when `seq` has fewer than `n` elements.
    """
    iterator = iter(seq)
    buf = list(islice(iterator, n))
    if len(buf) < n:
        return
    yield " ".join(buf)
    for item in iterator:
        buf = buf[1:] + [item]
        yield " ".join(buf)
def extract_story_parts(story):
    """Split a story into (beginning, middles, ending).

    beginning/ending are the first/last sentences; middles is a list of
    4-sentence sliding windows over the remaining middle sentences.

    BUG FIXES vs. the original:
    - stories with fewer than two sentences raised IndexError on pop();
      they now return (None, [], None), which the caller already guards for;
    - the windows are materialized into a list — the original created a lazy
      generator over a list it then mutated, and the generator could only be
      consumed once (the summarizer would exhaust it before the zip()).
    """
    sentences = get_sentences(story)
    if len(sentences) < 2:
        return None, [], None
    beginning = sentences.pop(0)
    ending = sentences.pop(-1)
    middles = list(window(sentences, 4))
    return beginning, middles, ending
def clear_prompt(prompt):
    """Strip a single leading "Write " / "write " from the prompt, if any."""
    for prefix in ("Write ", "write "):
        if prompt.startswith(prefix):
            return prompt[len(prefix):]
    return prompt
def get_sample_dict(split, id, text):
    """Package one generated dialog sample as a record dict."""
    return dict(split=split, splitLineIndex=id, text=text)
def generate_instruction_diologs(df):
    """Generate instruction-style dialog samples for every prompt in `df`.

    NOTE(review): this one-argument version is shadowed by the two-argument
    redefinition further below, which is what the driver loop calls.

    BUG FIXES vs. the original:
    - `row.splitIndex` -> `row.splitLineIndex` (the column created by
      load_data; the old attribute raised AttributeError, and the except
      handler's f-string raised again on the same attribute);
    - the progress bar total is per-prompt, so update() moved to the prompt
      loop (it previously advanced once per story);
    - `middles` is materialized before summarizing so the zip() below still
      has items to pair (a generator would be exhausted by the summarizer).
    """
    dialogs = []
    dialogBase = """User: write me a story about: {stripped_prompt}"""
    dialog1 = """ -> Rosey: Sure, here's a story about: {stripped_prompt}:\n{story}"""
    dialog2 = """, {stripped_constraint} -> Rosey: Sure, here's a story about: {stripped_prompt}, {stripped_constraint}:\n{story}"""
    dialog3 = """, starting with: {beginning} -> Rosey: Sure, here's a story about: {stripped_prompt}, starting with: {beginning}:\n{story}"""
    dialog4 = """, ending with: {ending} -> Rosey: Sure, here's a story about {stripped_prompt}: ending with: {ending}\n{story}"""
    dialog5 = """, where the middle of the story is about: {middle} -> Rosey: Sure, here's a story about: {stripped_prompt}, where the middle of the story is about: {middle}:\n{story}"""
    # Most-repeated prompts first.
    df_rep = df.groupby(["prompt"]).size().reset_index().rename(columns={0: "records"})
    df_rep.sort_values(["records"], ascending=False, inplace=True)
    pbar = tqdm()
    pbar.reset(total=len(df_rep))
    for prompt in df_rep.iloc[:, 0]:
        strippedPrompt = extract_prompt_parts(prompt, PROMPT_PATTERNS)
        if strippedPrompt is None:
            continue  # no template matches this prompt phrasing
        strippedPrompt = clear_prompt(strippedPrompt)
        strippedConstraint = extract_prompt_parts(prompt, CONST_PATTERNS)
        for row in df[df["prompt"] == prompt].itertuples():
            try:
                # Restore real newlines from the dataset's placeholder tokens.
                story = (
                    row.story.replace("<newline>", "\n")
                    .replace("< newline >", "\n")
                    .replace("<new line>", "\n")
                    .strip()
                )
                beginning, middles, ending = extract_story_parts(story)
                dialogBeg = dialogBase.format(stripped_prompt=strippedPrompt)
                dialog = dialogBeg + dialog1.format(story=story, stripped_prompt=strippedPrompt)
                dialogs.append(get_sample_dict(row.split, row.splitLineIndex, dialog))
                if strippedConstraint is not None:
                    dialog = dialogBeg + dialog2.format(
                        stripped_prompt=strippedPrompt, stripped_constraint=strippedConstraint, story=story
                    )
                    dialogs.append(get_sample_dict(row.split, row.splitLineIndex, dialog))
                dialog = dialogBeg + dialog3.format(stripped_prompt=strippedPrompt, story=story, beginning=beginning)
                dialogs.append(get_sample_dict(row.split, row.splitLineIndex, dialog))
                dialog = dialogBeg + dialog4.format(stripped_prompt=strippedPrompt, story=story, ending=ending)
                dialogs.append(get_sample_dict(row.split, row.splitLineIndex, dialog))
                middles = list(middles)
                middlesSumarizered = summarizer(middles, **params)
                for middle, sumarizedMiddle in zip(middles, middlesSumarizered):
                    # NOTE(review): assumes each pipeline result looks like
                    # [{"summary_text": ...}] — confirm against the installed
                    # transformers version.
                    dialog = dialogBeg + dialog5.format(
                        stripped_prompt=strippedPrompt, story=story, middle=sumarizedMiddle[0]["summary_text"]
                    )
                    dialogs.append(get_sample_dict(row.split, row.splitLineIndex, dialog))
            except Exception as e:
                print(f"{row.split}/{row.splitLineIndex}")
                raise e
        pbar.update()
    pbar.refresh()
    return dialogs
def filter_data(
    dataset,
    negativeTagFilter=None,
    positiveTagFilter=None,
    patternFilter=None,
):
    """Filter a {"prompt": df, "story": df} pair by tag columns and text.

    negativeTagFilter keeps rows where ANY listed tag column is 0;
    positiveTagFilter keeps rows where ALL listed tag columns are > 0;
    patternFilter keeps rows whose prompt text contains the pattern.

    > filter_data(dataset['train'],negativeTagFilter=['ip'], positiveTagFilter=['pm'] )
    """
    prompts = dataset["prompt"]
    if negativeTagFilter is not None:
        keep = (prompts[negativeTagFilter] < 1).any(axis=1)
        prompts = prompts[keep]
    if positiveTagFilter is not None:
        keep = prompts[positiveTagFilter].gt(0).all(axis=1)
        prompts = prompts[keep]
    if patternFilter is not None:
        prompts = prompts[prompts["prompt"].str.contains(patternFilter)]
    # Stories are aligned positionally with the prompts' original index.
    stories = dataset["story"].iloc[prompts.index]
    return {"prompt": prompts, "story": stories}
def generate_instruction_diologs(prompt, df):
    """Generate instruction-style dialog samples for a single prompt.

    For every story attached to `prompt`, emits the base template plus the
    constraint / beginning / ending / summarized-middle variants described
    at the top of this notebook.

    Parameters:
        prompt: raw prompt text, used to select matching rows of `df`.
        df: DataFrame from load_data() (prompt/story/split/splitLineIndex).
    Returns:
        list of sample dicts (see get_sample_dict); empty when no template
        pattern matches the prompt.

    BUG FIXES vs. the original:
    - extract_story_parts was called twice per story (duplicate call);
    - `middles` is materialized before summarizing so the zip() below still
      has items to pair (a generator would be exhausted by the summarizer);
    - the matching rows are filtered once instead of twice.
    """
    dialogs = []
    dialogBase = """User: write me a story about: {stripped_prompt}"""
    dialog1 = """ -> Rosey: Sure, here's a story about: {stripped_prompt}:\n{story}"""
    dialog2 = """, {stripped_constraint} -> Rosey: Sure, here's a story about: {stripped_prompt}, {stripped_constraint}:\n{story}"""
    dialog3 = """, starting with: {beginning} -> Rosey: Sure, here's a story about: {stripped_prompt}, starting with: {beginning}:\n{story}"""
    dialog4 = """, ending with: {ending} -> Rosey: Sure, here's a story about {stripped_prompt}: ending with: {ending}\n{story}"""
    dialog5 = """, where the middle of the story is about: {middle} -> Rosey: Sure, here's a story about: {stripped_prompt}, where the middle of the story is about: {middle}:\n{story}"""
    strippedPrompt = extract_prompt_parts(prompt, PROMPT_PATTERNS)
    if strippedPrompt is not None:
        strippedPrompt = clear_prompt(strippedPrompt)
        strippedConstraint = extract_prompt_parts(prompt, CONST_PATTERNS)
        matching = df[df["prompt"] == prompt]
        pbar = tqdm(ascii=True, desc="stories")
        pbar.reset(total=len(matching))
        for row in matching.itertuples():
            try:
                # Restore real newlines from the dataset's placeholder tokens.
                story = (
                    row.story.replace("<newline>", "\n")
                    .replace("< newline >", "\n")
                    .replace("<new line>", "\n")
                    .strip()
                )
                dialogBeg = dialogBase.format(stripped_prompt=strippedPrompt)
                dialog = dialogBeg + dialog1.format(story=story, stripped_prompt=strippedPrompt)
                dialogs.append(get_sample_dict(row.split, row.splitLineIndex, dialog))
                if strippedConstraint is not None:
                    dialog = dialogBeg + dialog2.format(
                        stripped_prompt=strippedPrompt, stripped_constraint=strippedConstraint, story=story
                    )
                    dialogs.append(get_sample_dict(row.split, row.splitLineIndex, dialog))
                beginning, middles, ending = extract_story_parts(story)
                if beginning is not None:
                    dialog = dialogBeg + dialog3.format(
                        stripped_prompt=strippedPrompt, story=story, beginning=beginning
                    )
                    dialogs.append(get_sample_dict(row.split, row.splitLineIndex, dialog))
                    dialog = dialogBeg + dialog4.format(stripped_prompt=strippedPrompt, story=story, ending=ending)
                    dialogs.append(get_sample_dict(row.split, row.splitLineIndex, dialog))
                    middles = list(middles)
                    middlesSumarizered = summarizer(middles, **params)
                    for middle, sumarizedMiddle in zip(middles, middlesSumarizered):
                        # NOTE(review): assumes each pipeline result looks
                        # like [{"summary_text": ...}] — confirm against the
                        # installed transformers version.
                        dialog = dialogBeg + dialog5.format(
                            stripped_prompt=strippedPrompt, story=story, middle=sumarizedMiddle[0]["summary_text"]
                        )
                        dialogs.append(get_sample_dict(row.split, row.splitLineIndex, dialog))
                pbar.update()
            except Exception as e:
                print(f"{row.split}/{row.splitLineIndex}")
                raise e
        pbar.refresh()
    return dialogs
The loop below checkpoints the accumulated samples to a parquet file every `step` prompts, to avoid losing work.
## filter dataset to take only prompts with frequency greater than 20 stories.
dialogs = []
i = 0  # NOTE(review): unused leftover
start = 0  # resume point, in case of a restart
step = 10  # checkpoint frequency: prompts processed between parquet saves
for index in range(start, len(topPrompts20Reps), step):
    pbar = tqdm(ascii=True, desc="prompt")
    pbar.reset(total=len(topPrompts20Reps[index : index + step]))
    for prompt in topPrompts20Reps[index : index + step]:
        tmpDialogs = generate_instruction_diologs(prompt, ds)
        if tmpDialogs is not None:
            dialogs += tmpDialogs
        pbar.update()
    # Checkpoint: overwrite the parquet with everything accumulated so far.
    if len(dialogs) > 0:
        pd.DataFrame(dialogs).to_parquet("writing-prompts-aug.parquet")
    pbar.refresh()
# Reload the checkpointed samples and write one parquet file per split.
df = pd.read_parquet("writing-prompts-aug.parquet")
for split in list(set(df.split)):
    df_aux = df[df["split"] == split].iloc[:, 1:]  # drop the 'split' column
    df_aux.reset_index(inplace=True)
    # iloc[:, 1:] drops the 'index' column added by reset_index above.
    df_aux.iloc[:, 1:].to_parquet(f"{split}.parquet")