4.3. map_partitions
#
除了 章节 4.4 中提到的一些需要通信的计算外,有一种最简单的并行方式,英文术语为 Embarrassingly Parallel,中文可翻译为易并行。它指的是该类计算不需要太多跨 Worker 的协调和通信。比如,对某个字段加一,每个 Worker 内执行加法操作即可,Worker 之间没有通信的开销。Dask DataFrame 中可以使用 map_partitions(func)
来做这类 Embarrassingly Parallel 的操作。map_partitions(func)
的参数是一个 func
,这个 func
将在每个 pandas DataFrame 上执行,func
内可以使用 pandas DataFrame 的各类操作。如 图 4.4 所示,map_partitions(func)
对原来的 pandas DataFrame 进行了转换操作。
案例:纽约出租车数据#
我们使用纽约出租车数据集进行简单的数据预处理:计算每个订单的时长。原数据集中,tpep_pickup_datetime
和 tpep_dropoff_datetime
分别为乘客上车和下车时间,现在只需要将下车时间 tpep_dropoff_datetime
减去上车时间 tpep_pickup_datetime
。这个计算没有跨 Worker 的通信开销,因此是一种 Embarrassingly Parallel 的典型应用场景。
import sys
sys.path.append("..")
from utils import nyc_taxi
import pandas as pd
import dask
dask.config.set({'dataframe.query-planning': False})
import dask.dataframe as dd
import pandas as pd
from dask.distributed import LocalCluster, Client
cluster = LocalCluster()
client = Client(cluster)
dataset_path = nyc_taxi()
ddf = dd.read_parquet(dataset_path)
def transform(df):
df["trip_duration"] = (df["tpep_dropoff_datetime"] - df["tpep_pickup_datetime"]).dt.seconds
# 将 `trip_duration` 挪到前面
dur_column = df.pop('trip_duration')
df.insert(1, dur_column.name, dur_column)
return df
ddf = ddf.map_partitions(transform)
ddf.compute()
ddf.head(5)
VendorID | trip_duration | tpep_pickup_datetime | tpep_dropoff_datetime | passenger_count | trip_distance | RatecodeID | store_and_fwd_flag | PULocationID | DOLocationID | payment_type | fare_amount | extra | mta_tax | tip_amount | tolls_amount | improvement_surcharge | total_amount | congestion_surcharge | Airport_fee | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 1 | 1253 | 2023-06-01 00:08:48 | 2023-06-01 00:29:41 | 1.0 | 3.40 | 1.0 | N | 140 | 238 | 1 | 21.9 | 3.50 | 0.5 | 6.70 | 0.0 | 1.0 | 33.60 | 2.5 | 0.00 |
1 | 1 | 614 | 2023-06-01 00:15:04 | 2023-06-01 00:25:18 | 0.0 | 3.40 | 1.0 | N | 50 | 151 | 1 | 15.6 | 3.50 | 0.5 | 3.00 | 0.0 | 1.0 | 23.60 | 2.5 | 0.00 |
2 | 1 | 1123 | 2023-06-01 00:48:24 | 2023-06-01 01:07:07 | 1.0 | 10.20 | 1.0 | N | 138 | 97 | 1 | 40.8 | 7.75 | 0.5 | 10.00 | 0.0 | 1.0 | 60.05 | 0.0 | 1.75 |
3 | 2 | 1406 | 2023-06-01 00:54:03 | 2023-06-01 01:17:29 | 3.0 | 9.83 | 1.0 | N | 100 | 244 | 1 | 39.4 | 1.00 | 0.5 | 8.88 | 0.0 | 1.0 | 53.28 | 2.5 | 0.00 |
4 | 2 | 514 | 2023-06-01 00:18:44 | 2023-06-01 00:27:18 | 1.0 | 1.17 | 1.0 | N | 137 | 234 | 1 | 9.3 | 1.00 | 0.5 | 0.72 | 0.0 | 1.0 | 15.02 | 2.5 | 0.00 |
Dask DataFrame 的某些 API 是 Embarrassingly Parallel 的,它的底层就是使用 map_partitions()
实现的。
章节 4.2 提到过,Dask DataFrame 会在某个列(索引列)上进行切分,但如果 map_partitions()
对这些索引列做了改动,需要 clear_divisions()
或者重新 set_index()
。
ddf.clear_divisions()
VendorID | trip_duration | tpep_pickup_datetime | tpep_dropoff_datetime | passenger_count | trip_distance | RatecodeID | store_and_fwd_flag | PULocationID | DOLocationID | payment_type | fare_amount | extra | mta_tax | tip_amount | tolls_amount | improvement_surcharge | total_amount | congestion_surcharge | Airport_fee | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
npartitions=1 | ||||||||||||||||||||
int32 | int32 | datetime64[us] | datetime64[us] | int64 | float64 | int64 | string | int32 | int32 | int64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
client.shutdown()