map_partitions

4.3. map_partitions#

除了 章节 4.4 中提到的一些需要通信的计算外,有一种最简单的并行方式,英文术语为 Embarrassingly Parallel,中文可翻译为易并行。它指的是该类计算不需要太多跨 Worker 的协调和通信。比如,对某个字段加一,每个 Worker 内执行加法操作即可,Worker 之间没有通信的开销。Dask DataFrame 中可以使用 map_partitions(func) 来做这类 Embarrassingly Parallel 的操作。map_partitions(func) 的参数是一个 func,这个 func 将在每个 pandas DataFrame 上执行,func 内可以使用 pandas DataFrame 的各类操作。如 图 4.4 所示,map_partitions(func) 对原来的 pandas DataFrame 进行了转换操作。

../_images/map-partitions.svg

图 4.4 map_partitions()#

案例:纽约出租车数据#

我们使用纽约出租车数据集进行简单的数据预处理:计算每个订单的时长。原数据集中,tpep_pickup_datetimetpep_dropoff_datetime 分别为乘客上车和下车时间,现在只需要将下车时间 tpep_dropoff_datetime 减去上车时间 tpep_pickup_datetime。这个计算没有跨 Worker 的通信开销,因此是一种 Embarrassingly Parallel 的典型应用场景。

import sys
sys.path.append("..")
from utils import nyc_taxi

import pandas as pd
import dask
dask.config.set({'dataframe.query-planning': False})
import dask.dataframe as dd
import pandas as pd
from dask.distributed import LocalCluster, Client

cluster = LocalCluster()
client = Client(cluster)
dataset_path = nyc_taxi()
ddf = dd.read_parquet(dataset_path)
def transform(df):
    df["trip_duration"] = (df["tpep_dropoff_datetime"] - df["tpep_pickup_datetime"]).dt.seconds
    # 将 `trip_duration` 挪到前面
    dur_column = df.pop('trip_duration')
    df.insert(1, dur_column.name, dur_column)
    return df

ddf = ddf.map_partitions(transform)
ddf.compute()
ddf.head(5)
VendorID trip_duration tpep_pickup_datetime tpep_dropoff_datetime passenger_count trip_distance RatecodeID store_and_fwd_flag PULocationID DOLocationID payment_type fare_amount extra mta_tax tip_amount tolls_amount improvement_surcharge total_amount congestion_surcharge Airport_fee
0 1 1253 2023-06-01 00:08:48 2023-06-01 00:29:41 1.0 3.40 1.0 N 140 238 1 21.9 3.50 0.5 6.70 0.0 1.0 33.60 2.5 0.00
1 1 614 2023-06-01 00:15:04 2023-06-01 00:25:18 0.0 3.40 1.0 N 50 151 1 15.6 3.50 0.5 3.00 0.0 1.0 23.60 2.5 0.00
2 1 1123 2023-06-01 00:48:24 2023-06-01 01:07:07 1.0 10.20 1.0 N 138 97 1 40.8 7.75 0.5 10.00 0.0 1.0 60.05 0.0 1.75
3 2 1406 2023-06-01 00:54:03 2023-06-01 01:17:29 3.0 9.83 1.0 N 100 244 1 39.4 1.00 0.5 8.88 0.0 1.0 53.28 2.5 0.00
4 2 514 2023-06-01 00:18:44 2023-06-01 00:27:18 1.0 1.17 1.0 N 137 234 1 9.3 1.00 0.5 0.72 0.0 1.0 15.02 2.5 0.00

Dask DataFrame 的某些 API 是 Embarrassingly Parallel 的,它的底层就是使用 map_partitions() 实现的。

章节 4.2 提到过,Dask DataFrame 会在某个列(索引列)上进行切分,但如果 map_partitions() 对这些索引列做了改动,需要 clear_divisions() 或者重新 set_index()

ddf.clear_divisions()
Dask DataFrame Structure:
VendorID trip_duration tpep_pickup_datetime tpep_dropoff_datetime passenger_count trip_distance RatecodeID store_and_fwd_flag PULocationID DOLocationID payment_type fare_amount extra mta_tax tip_amount tolls_amount improvement_surcharge total_amount congestion_surcharge Airport_fee
npartitions=1
int32 int32 datetime64[us] datetime64[us] int64 float64 int64 string int32 int32 int64 float64 float64 float64 float64 float64 float64 float64 float64 float64
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
Dask Name: transform, 2 graph layers
client.shutdown()