Retrieve top n in each group of a DataFrame in pyspark


Question or problem about Python programming:

There’s a DataFrame in pyspark with data as below:

user_id object_id score
user_1  object_1  3
user_1  object_1  1
user_1  object_2  2
user_2  object_1  5
user_2  object_2  2
user_2  object_2  6

What I expect is to return the 2 records with the highest score within each group of rows sharing the same user_id. Consequently, the result should look like the following:

user_id object_id score
user_1  object_1  3
user_1  object_2  2
user_2  object_2  6
user_2  object_1  5

I’m really new to pyspark; could anyone give me a code snippet or a pointer to the relevant documentation for this problem? Many thanks!

How to solve the problem:

Solution 1:

I believe you need to use window functions to compute the rank of each row based on user_id and score, and then filter your results to keep only the first two values in each group.

from pyspark.sql.window import Window
from pyspark.sql.functions import rank, col

window = Window.partitionBy(df['user_id']).orderBy(df['score'].desc())

df.select('*', rank().over(window).alias('rank')) \
  .filter(col('rank') <= 2) \
  .show()
#+-------+---------+-----+----+
#|user_id|object_id|score|rank|
#+-------+---------+-----+----+
#| user_1| object_1|    3|   1|
#| user_1| object_2|    2|   2|
#| user_2| object_2|    6|   1|
#| user_2| object_1|    5|   2|
#+-------+---------+-----+----+

In general, the official programming guide is a good place to start learning Spark.

Data
rdd = sc.parallelize([("user_1",  "object_1",  3), 
                      ("user_1",  "object_2",  2), 
                      ("user_2",  "object_1",  5), 
                      ("user_2",  "object_2",  2), 
                      ("user_2",  "object_2",  6)])
df = sqlContext.createDataFrame(rdd, ["user_id", "object_id", "score"])
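
If you are on a newer Spark version where sc and sqlContext are not pre-created, the same DataFrame can be built from a SparkSession instead; a minimal sketch, assuming a local session (the app name is arbitrary) and using the rows from the question:

from pyspark.sql import SparkSession

# Assumption: a local SparkSession; adjust master/appName to your environment.
spark = SparkSession.builder.master("local[*]").appName("top-n-per-group").getOrCreate()

df = spark.createDataFrame(
    [("user_1", "object_1", 3),
     ("user_1", "object_1", 1),
     ("user_1", "object_2", 2),
     ("user_2", "object_1", 5),
     ("user_2", "object_2", 2),
     ("user_2", "object_2", 6)],
    ["user_id", "object_id", "score"])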

Solution 2:

Top-n is more accurate when using row_number instead of rank, because rank can return more than n rows per group when there are ties in the score:

from pyspark.sql.functions import col, row_number

n = 5
# reuses the window definition from Solution 1
df.select(col('*'), row_number().over(window).alias('row_number')) \
  .where(col('row_number') <= n) \
  .limit(20) \
  .toPandas()


Note the limit(20).toPandas() trick, used instead of show(), which gives nicer formatting in Jupyter notebooks.
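
To see the difference on tied scores, you can compute both functions over the same window; rows with equal scores share a rank (and the following rank is skipped), while row_number always assigns distinct, consecutive numbers. A quick sketch, assuming the df and window from Solution 1:

from pyspark.sql.functions import rank, row_number, col

# Compare the two ranking functions side by side on the same window.
df.select('*',
          rank().over(window).alias('rank'),
          row_number().over(window).alias('row_number')) \
  .show()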

Solution 3:

I know the question was asked for pyspark, but I was looking for a similar answer in Scala, i.e.


Retrieve top n values in each group of a DataFrame in Scala

Here is the Scala version of @mtoto's answer.

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, rank}

val window = Window.partitionBy("user_id").orderBy(col("score").desc)
val rankByScore = rank().over(window)

df1.select(col("*"), rankByScore.alias("rank"))
  .filter(col("rank") <= 2)
  .show()
// change the value 2 to any number you want; here 2 means the top 2 values

More examples can be found here.

Solution 4:

With Python 3 and Spark 2.4:

from pyspark.sql import Window
import pyspark.sql.functions as f

def get_topN(df, group_by_columns, order_by_column, n=1):
    window_group_by_columns = Window.partitionBy(group_by_columns)
    ordered_df = df.select(df.columns + [
        f.row_number()
         .over(window_group_by_columns.orderBy(order_by_column.desc()))
         .alias('row_rank')])
    topN_df = ordered_df.filter(f"row_rank <= {n}").drop("row_rank")
    return topN_df

top_n_df = get_topN(your_dataframe, group_by_columns, order_by_column, 1)
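
For example, applied to the question's DataFrame (assuming it is named df, as in Solution 1), the call could look like this:

import pyspark.sql.functions as f

# Keep the 2 highest-scoring rows per user_id, as requested in the question.
top_2_df = get_topN(df, ["user_id"], f.col("score"), n=2)
top_2_df.show()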

Solution 5:

To find the Nth highest value in a PySpark SQL query, use the ROW_NUMBER() function:

SELECT *
FROM (
    SELECT e.*, ROW_NUMBER() OVER (ORDER BY col_name DESC) rn
    FROM Employee e
)
WHERE rn = N

where N is the rank of the value you want from the column (e.g. N = 2 for the second-highest value).

Output:

+-----------+
|col_name   |
+-----------+
|1183395    |
+-----------+

The query returns the Nth highest value.
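
To run such a query from PySpark, the DataFrame first has to be registered as a temporary view; a minimal sketch on the question's data, assuming the df and spark session from above (the view name and N = 2 are illustrative choices):

# Register the DataFrame from the question as a SQL temp view (name is arbitrary).
df.createOrReplaceTempView("scores")

# Nth highest score overall (here N = 2).
spark.sql("""
    SELECT score
    FROM (
        SELECT s.*, ROW_NUMBER() OVER (ORDER BY score DESC) rn
        FROM scores s
    )
    WHERE rn = 2
""").show()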

Hope this helps!