
Chapter 3: Function Junction - Data manipulation with PySpark

Clean data

In data science, garbage in, garbage out (GIGO) is the idea that flawed, biased, or poor-quality input produces output of similar quality. To improve the quality of an analysis, we need data cleaning, the process that turns garbage into gold: identifying, correcting, or removing errors and inconsistencies in data to improve its quality and usability.

Let's start with a DataFrame containing bad values:

python
!pip install pyspark==4.0.0.dev2
python
from pyspark.sql import SparkSession

spark = SparkSession \
    .builder \
    .appName("Data Loading and Storage Example") \
    .getOrCreate()
python
from pyspark.sql import Row

df = spark.createDataFrame([
    Row(age=10, height=80.0, NAME="Alice"),
    Row(age=10, height=80.0, NAME="Alice"),
    Row(age=5, height=float("nan"), NAME="BOB"),
    Row(age=None, height=None, NAME="Tom"),
    Row(age=None, height=float("nan"), NAME=None),
    Row(age=9, height=78.9, NAME="josh"),
    Row(age=18, height=1802.3, NAME="bush"),
    Row(age=7, height=75.3, NAME="jerry"),
])

df.show()

Rename columns

At first glance, we notice that the column NAME is in upper case. For consistency, we can use DataFrame.withColumnRenamed to rename it.

python
df2 = df.withColumnRenamed("NAME", "name")

df2.show()

Drop null values

Next, we notice two kinds of missing data:

  • NULL values, which appear in all three columns;
  • NaN (Not a Number) values, which appear in the numeric column height.

Records without a valid name are likely useless, so let's drop them first. DataFrameNaFunctions provides a group of functions for handling missing values; we can use DataFrame.na.drop or DataFrame.dropna to omit rows containing NULL or NaN values.

After the step df2.na.drop(subset="name"), the invalid record (age=None, height=NaN, name=None) is discarded.

python
df3 = df2.na.drop(subset="name")

df3.show()

Fill values

For the remaining missing values, we can use DataFrame.na.fill or DataFrame.fillna to fill them.

With a dict input {'age': 10, 'height': 80.1}, we can specify fill values for the age and height columns in a single call.

python
df4 = df3.na.fill({'age': 10, 'height': 80.1})

df4.show()

Remove outliers

After the above steps, all missing values have been dropped or filled. However, height=1802.3 looks unreasonable. To remove outliers like this, we can keep only the rows whose height falls within a valid range such as 65 to 85.

python
df5 = df4.where(df4.height.between(65, 85))

df5.show()

Remove duplicates

Now all invalid records have been handled, but the record (age=10, height=80.0, name=Alice) appears twice. To remove such duplicates, we can simply apply DataFrame.distinct.

python
df6 = df5.distinct()

df6.show()

String manipulation

The name column contains both lower-case and upper-case letters. We can apply the lower() function to convert all letters to lower case.

python
from pyspark.sql import functions as sf

df7 = df6.withColumn("name", sf.lower("name"))
df7.show()

For more complicated string manipulations, we can also use a udf (user-defined function) to take advantage of Python's powerful string methods.

python
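from pyspark.sql import functions as sf

# A Python UDF; with no returnType given, the result type defaults to string.
# NULL values arrive in Python as None, so pass them through unchanged.
capitalize = sf.udf(lambda s: s.capitalize() if s is not None else None)

# Starts from df6, so this is an alternative to the lower() step above.
df8 = df6.withColumn("name", capitalize("name"))
df8.show()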

Reorder columns

After the above process, the data is clean, and we may want to reorder the columns before saving the DataFrame to storage. See the previous chapter, Load and Behold: Data loading, storage, file formats, for more details on saving.

Normally, we use DataFrame.select for this purpose.

python
df9 = df7.select("name", "age", "height")

df9.show()

Transform data

The main part of a data engineering project is transformation: creating new DataFrames from existing ones.

Select columns with select()

The input table may contain hundreds of columns, but a specific project is likely interested in only a small subset of them.

python
from pyspark.sql import functions as sf

df = spark.range(10)

# Add 20 constant columns, col_0 ... col_19, next to the "id" column from range().
for i in range(20):
    df = df.withColumn(f"col_{i}", sf.lit(i))

df.show()

The loop above creates a DataFrame with 21 columns (id plus col_0 through col_19). We then use select to keep only 4 of them: id, col_2, and col_3 are taken directly from the previous DataFrame, while sqrt_col_4_plus_5 is computed with math functions.

PySpark provides hundreds of functions for column manipulation in pyspark.sql.functions and pyspark.sql.Column.

python
df2 = df.select(
    "id",
    "col_2",
    "col_3",
    sf.sqrt(sf.col("col_4") + sf.col("col_5")).alias("sqrt_col_4_plus_5"),
)

df2.show()

Filter rows with where()

The input table may also be huge, containing billions of rows, of which we need only a small subset.

We can use where or filter (the two are aliases) with specified conditions to filter the rows.

For example, we can select rows with odd id values.

python
df3 = df2.where(sf.col("id") % 2 == 1)

df3.show()

Summarizing data

In data analysis, we normally end up summarizing data into a chart or table.

python
from pyspark.sql import Row

df = spark.createDataFrame([
    Row(incomes=[123.0, 456.0, 789.0], NAME="Alice"),
    Row(incomes=[234.0, 567.0], NAME="BOB"),
    Row(incomes=[100.0, 200.0, 100.0], NAME="Tom"),
    Row(incomes=[79.0, 128.0], NAME="josh"),
    Row(incomes=[123.0, 145.0, 178.0], NAME="bush"),
    Row(incomes=[111.0, 187.0, 451.0, 188.0, 199.0], NAME="jerry"),
])

df.show()

For example, given the income per month, we want to find the average income for each name.

python
from pyspark.sql import functions as sf

df2 = df.select(sf.lower("NAME").alias("name"), "incomes")

df2.show(truncate=False)

Reshape data using explode()

To make the data easier to aggregate, we can reshape it with the explode() function, which produces one row per element of the array.

python
df3 = df2.select("name", sf.explode("incomes").alias("income"))

df3.show()

Summarizing via groupBy() and agg()

We then normally use DataFrame.groupBy(...).agg(...) to aggregate the data. To compute the average income, we can apply the aggregate function avg.

python
df4 = df3.groupBy("name").agg(sf.avg("income").alias("avg_income"))

df4.show()

Order rows with orderBy()

For final analysis, we normally want to order the data. In this case, we can order the data by name.

python
df5 = df4.orderBy("name")

df5.show()

When DataFrames Collide: The Art of Joining

When dealing with multiple DataFrames, we likely need to combine them in some way. The most frequently used approach is joining.

For example, given the incomes data and height data, we can use DataFrame.join to join them together by name.

Only alice, josh, and bush appear in the result, because they are the only names present in both DataFrames.

python
from pyspark.sql import Row

df1 = spark.createDataFrame([
    Row(age=10, height=80.0, name="alice"),
    Row(age=9, height=78.9, name="josh"),
    Row(age=18, height=82.3, name="bush"),
    Row(age=7, height=75.3, name="tom"),
])

df2 = spark.createDataFrame([
    Row(incomes=[123.0, 456.0, 789.0], name="alice"),
    Row(incomes=[234.0, 567.0], name="bob"),
    Row(incomes=[79.0, 128.0], name="josh"),
    Row(incomes=[123.0, 145.0, 178.0], name="bush"),
    Row(incomes=[111.0, 187.0, 451.0, 188.0, 199.0], name="jerry"),
])
python
df3 = df1.join(df2, on="name")

df3.show(truncate=False)

There are seven join types, chosen with the how parameter of DataFrame.join:

  • INNER
  • LEFT
  • RIGHT
  • FULL
  • CROSS
  • LEFTSEMI
  • LEFTANTI

The default is INNER.

Let's take a LEFT join as another example. A left join keeps all records from the first (left) table, even if they have no matching record in the second (right) table; for those records, the right table's columns are filled with NULL.

python
df4 = df1.join(df2, on="name", how="left")

df4.show(truncate=False)

Similarly, a RIGHT join keeps all records from the right table.

python
df5 = df1.join(df2, on="name", how="right")

df5.show(truncate=False)