Unlocking Ray Data: A Deep Dive into the Dataset API and Its Powerful Transformations
This comprehensive guide explains Ray Data's Dataset core type, its distributed pipeline design, lazy execution model, API groups, transformation methods, column operations, I/O integrations, metadata utilities, and execution workflow, providing clear code examples and practical usage tips.
Overview and Design Philosophy
The Dataset class is the central distributed data collection type in Ray Data, built around a pipeline that emits ObjectRef[Block] objects, where each Block holds a slice of the data (typically as an Arrow table, or a pandas DataFrame). Blocks define the granularity of parallelism, and transformations execute lazily, triggering only when a downstream consumer requests data.
Data Sources and Sinks
Datasets are created via read_*() functions for external storage, from_*() for in‑memory data, or range_*() for synthetic data. They are written using write_*() methods supporting formats such as Parquet, CSV, JSON, Iceberg, Mongo, BigQuery, and more.
Core Architecture
Ray Data separates physical execution from logical planning:
- _plan: ExecutionPlan – physical operators, scheduling, and resource handling.
- _logical_plan: LogicalPlan – the DAG used for optimizations and push-downs.
The two plans are linked via _plan.link_logical_plan(logical_plan). The constructor is hidden; users create Datasets through the public ray.data.* API (e.g., ray.data.read_parquet(), ray.data.range()).
API Group Constants
Methods are organized into groups identified by constants such as BT_API_GROUP (basic transformations), SSR_API_GROUP (sorting, shuffling, repartitioning), SMJ_API_GROUP (splitting, merging, joining), GGA_API_GROUP (grouped and global aggregations), CD_API_GROUP (consuming data), IOC_API_GROUP (I/O and conversion), IM_API_GROUP (metadata inspection), and E_API_GROUP (execution and materialization).
Basic Transformations (BT_API_GROUP)
- map(fn, ...) – apply a function to each row.
- map_batches(fn, batch_size=..., batch_format=..., ...) – apply a function to each batch; supports NumPy, pandas, PyArrow, and GPU batch sizing.
- flat_map(fn, ...) – map each row to multiple rows and flatten the result.
Example using a plain function:

import ray
from typing import Any, Dict

def add_offset(row: Dict[str, Any]) -> Dict[str, Any]:
    # Shift the "id" column of each row by a fixed offset.
    row["id"] = row["id"] + 100
    return row

ds = ray.data.range(5)
ds = ds.map(add_offset)
print(ds.take_all())  # [{'id': 100}, {'id': 101}, ...]

Example using a callable class (stateful transformation):
import ray
from typing import Any, Dict

class OffsetAdder:
    def __init__(self, offset: int, scale: int = 1):
        self.offset = offset
        self.scale = scale

    def __call__(self, row: Dict[str, Any]) -> Dict[str, Any]:
        return {"id": row["id"] * self.scale + self.offset}

ds = ray.data.range(5)
# Callable classes run in an actor pool; recent Ray versions require a
# concurrency (pool size) argument when fn is a class.
ds = ds.map(OffsetAdder, fn_constructor_args=[100],
            fn_constructor_kwargs={"scale": 2}, concurrency=2)
print(ds.take_all())  # [{'id': 100}, {'id': 102}, ...]

Callable classes must implement __call__(self, row) so Ray can instantiate the class on each worker and invoke it like a function.
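For comparison, here is a brief sketch of map_batches and flat_map on the same synthetic data; pandas is one of the batch formats mentioned above, and the helper names are illustrative only:

import ray
import pandas as pd

ds = ray.data.range(5)

# map_batches: operate on whole batches instead of single rows.
def double_ids(batch: pd.DataFrame) -> pd.DataFrame:
    batch["id"] = batch["id"] * 2
    return batch

doubled = ds.map_batches(double_ids, batch_format="pandas", batch_size=2)
print(doubled.take_all())  # [{'id': 0}, {'id': 2}, {'id': 4}, ...]

# flat_map: each input row may produce several output rows.
repeated = ds.flat_map(lambda row: [{"id": row["id"]}, {"id": row["id"]}])
print(repeated.count())  # 10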
Expressions and Column Operations (Alpha)
- with_column(col, expr, ...) – add a column via an expression or UDF.
- add_column(col, fn, ...) (deprecated) – previously used map_batches to add columns.
- drop_columns(cols, ...) – remove specified columns.
- select_columns(cols, ...) – keep only selected columns (Parquet column-pruning supported).
- rename_columns(names, ...) – rename columns using a dict or list.
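A minimal sketch of these column operations, assuming a recent Ray build where the alpha expressions module (ray.data.expressions) exposes col(); the column names are illustrative:

import ray
from ray.data.expressions import col  # alpha expressions API (assumed available)

ds = ray.data.range(5)                         # single column: id
ds = ds.with_column("scaled", col("id") * 10)  # add a derived column via an expression
ds = ds.rename_columns({"scaled": "id_x10"})
ds = ds.select_columns(["id_x10"])             # keep only the selected column
print(ds.take(2))  # [{'id_x10': 0}, {'id_x10': 10}]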
Filtering and Sampling
- filter(fn=None, expr=None, ...) – keep rows matching a function or a ray.data.expressions.Expr.
- random_sample(fraction, seed=...) – randomly sample a fraction of rows.
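A quick sketch of both methods on synthetic data:

import ray

ds = ray.data.range(100)

# Keep only rows matching a plain Python predicate.
evens = ds.filter(lambda row: row["id"] % 2 == 0)
print(evens.count())  # 50

# Randomly keep roughly 10% of the rows (exact count varies per run).
sample = ds.random_sample(0.1, seed=42)
print(sample.count())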
Sorting, Shuffling, and Repartitioning (SSR_API_GROUP)
- repartition(num_blocks=None, target_num_rows_per_block=None, shuffle=..., keys=..., sort=...) – change the block count or rows per block, optionally shuffling or sorting.
- random_shuffle(seed=...) – full all-to-all random shuffle.
- randomize_block_order(seed=...) – shuffle only the block order.
- sort(key, descending=..., boundaries=...) – sort by a column, with optional partition boundaries.
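A short sketch chaining these operations:

import ray

ds = ray.data.range(1000)

ds = ds.repartition(10)              # redistribute into 10 blocks
ds = ds.random_shuffle(seed=7)       # full all-to-all shuffle
ds = ds.sort("id", descending=True)  # global sort by column
print(ds.take(3))  # [{'id': 999}, {'id': 998}, {'id': 997}]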
Splitting, Merging, and Joining (SMJ_API_GROUP)
- streaming_split(n, equal=..., locality_hints=...) – return n DataIterator objects for streaming splits (useful for distributed training).
- split(n, equal=..., locality_hints=...) – materialize and split into a list of MaterializedDataset objects.
- split_at_indices(indices) – split like np.split.
- split_proportionately(proportions) – split by given proportions (e.g., train/val).
- train_test_split(test_size, shuffle=..., seed=..., stratify=...) – classic train/test split with optional stratification.
- union(*other) – row-wise concatenation preserving order.
- join(ds, join_type, num_partitions, on=..., right_on=..., ...) – all-to-all join supporting multiple join types.
- zip(*other) – column-wise zip; conflicting column names get suffixes.
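A small sketch of train_test_split and zip (row counts depend on the split proportion and rounding):

import ray

ds = ray.data.range(10)

# Classic train/test split.
train, test = ds.train_test_split(test_size=0.25, shuffle=True, seed=0)
print(train.count(), test.count())  # e.g., 7 and 3

# Column-wise zip of two datasets with the same number of rows.
other = ray.data.range(10).map(lambda row: {"value": row["id"] * 10})
zipped = ds.zip(other)
print(zipped.take(1))  # [{'id': 0, 'value': 0}]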
Limiting
limit(limit) – logically keep only the first limit rows without pulling them to the driver.
Grouping and Aggregation (GGA_API_GROUP)
- groupby(key, num_partitions=...) – returns a GroupedData object for further map_groups or aggregate calls.
- unique(column, ignore_nulls=...) – distinct values of a column.
- aggregate(*aggs) – generic aggregation using AggregateFn objects.
- sum/min/max/mean/std(on=..., ignore_nulls=...) – single- or multi-column aggregations.
- summary(columns=..., override_dtype_agg_mapping=...) (alpha) – statistical summary per column.
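A compact sketch of grouped and global aggregations on a tiny in-memory dataset:

import ray

ds = ray.data.from_items([
    {"group": "a", "value": 1},
    {"group": "a", "value": 2},
    {"group": "b", "value": 10},
])

# Grouped aggregation: sum of "value" per "group".
per_group = ds.groupby("group").sum("value")
print(per_group.take_all())  # [{'group': 'a', 'sum(value)': 3}, {'group': 'b', 'sum(value)': 10}]

# Global aggregation and distinct values.
print(ds.sum("value"))     # 13
print(ds.unique("group"))  # ['a', 'b'] (order not guaranteed)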
Consuming Data (CD_API_GROUP)
- iterator() – returns a DataIteratorImpl for internal reuse.
- take(limit=20) – fetch up to limit rows to the driver.
- take_all(limit=...) – fetch all rows (optional limit to guard against OOM).
- take_batch(batch_size=..., batch_format=...) – fetch a single batch in the same format as a map_batches input.
- show(limit=20) – print the first limit rows.
- iter_rows() – row-wise iterator.
- iter_batches(prefetch_batches=..., batch_size=..., batch_format=..., drop_last=..., local_shuffle_*=...) – batch iterator with prefetch and local shuffle options.
- iter_torch_batches(...) – convert batches to PyTorch tensors.
- iter_tf_batches(...) (deprecated) – use to_tf instead.
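A brief sketch of the most common consumption paths:

import ray

ds = ray.data.range(100)

print(ds.take(3))  # first three rows as dicts
batch = ds.take_batch(batch_size=8, batch_format="pandas")
print(type(batch))  # <class 'pandas.core.frame.DataFrame'>

# Stream over the dataset in fixed-size pandas batches.
for batch in ds.iter_batches(batch_size=32, batch_format="pandas", drop_last=False):
    pass  # hand each batch to downstream processing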
Framework Integration
to_tf(feature_columns, label_columns, ...) – convert to a tf.data.Dataset.
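A minimal sketch, assuming TensorFlow is installed; the column names are illustrative:

import ray

ds = ray.data.from_items([{"x": float(i), "y": i % 2} for i in range(8)])

# Convert to a tf.data.Dataset yielding (features, labels) batches.
tf_ds = ds.to_tf(feature_columns="x", label_columns="y", batch_size=4)
for features, labels in tf_ds.take(1):
    print(features, labels)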
I/O and Format Conversion (IOC_API_GROUP)
- write_parquet(path, ...) – ParquetDatasink with partitioning and row-group control.
- write_json(path, ...) – JSONDatasink (supports JSONL).
- write_csv(path, ...) – CSVDatasink using the Arrow CSV writer.
- write_iceberg(table_identifier, ...) (alpha) – IcebergDatasink supporting append/upsert/overwrite.
- write_images(path, column, ...) (alpha) – ImageDatasink for column-wise image export.
- write_tfrecords(path, ...) – TFRecordDatasink.
- write_webdataset(path, ...) (alpha) – WebDataset format.
- write_numpy(path, column, ...) – NumpyDatasink for single-column .npy files.
- write_sql(sql, connection_factory, ...) – generic DB-API 2 sink.
- write_snowflake(...) – Snowflake via SQLDatasink.
- write_mongo(uri, database, collection, ...) (alpha) – MongoDatasink.
- write_bigquery(...) – BigQueryDatasink.
- write_clickhouse(...) – ClickHouseDatasink.
- write_lance(path, ...) – Lance table sink.
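A round-trip sketch with the Parquet sink; the local path is only a placeholder:

import ray

ds = ray.data.range(1000)

# Write the dataset out as Parquet files (one or more files per block).
ds.write_parquet("/tmp/ray_demo_parquet")

# Read it back and verify the row count.
back = ray.data.read_parquet("/tmp/ray_demo_parquet")
print(back.count())  # 1000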
Conversion to Other Libraries
- to_daft() – returns a daft.DataFrame.
- to_dask(meta=..., verify_meta=...) – returns a dask.dataframe.DataFrame.
- to_mars() – returns a mars.dataframe.DataFrame.
- to_modin() – returns a modin.pandas.DataFrame.
- to_spark(spark) – returns a pyspark.sql.DataFrame (requires RayDP).
- to_pandas(limit=...) – returns a pandas.DataFrame with an optional row limit.
- to_tf(...) – returns a tf.data.Dataset.
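The simplest conversion is to a local pandas DataFrame (suitable only for data that fits on the driver):

import ray

ds = ray.data.range(5)
df = ds.to_pandas()       # pandas.DataFrame with an "id" column
print(df["id"].tolist())  # [0, 1, 2, 3, 4]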
Metadata and Inspection (IM_API_GROUP)
- schema(fetch_if_missing=True) – returns a Schema; can infer via limit(1) if missing.
- columns(fetch_if_missing=True) – list of column names.
- count() – total rows (Parquet can read this from metadata).
- num_blocks() – only available on MaterializedDataset.
- size_bytes() – in-memory size when known.
- input_files() – list of source files.
- stats() – execution statistics string.
- explain() – prints the logical and physical plans after optimization.
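A quick inspection sketch:

import ray

ds = ray.data.range(100)

print(ds.schema())      # column "id" of type int64
print(ds.columns())     # ['id']
print(ds.count())       # 100
print(ds.size_bytes())  # in-memory size when known (may be None before execution)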
Execution and Materialization (E_API_GROUP)
materialize() forces the entire pipeline to run, materializing Blocks in Ray's object store and returning a new MaterializedDataset without altering the original.
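A minimal sketch of forcing execution and reusing the cached result:

import ray

ds = ray.data.range(1_000_000).map_batches(lambda batch: batch)  # still lazy

materialized = ds.materialize()  # runs the whole pipeline now
print(materialized.count())      # served from the cached blocks; ds itself is unchanged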
Developer and Internal APIs
- iter_internal_ref_bundles() – iterate over RefBundle objects without materialization.
- get_internal_block_refs() (deprecated) – replaced by iter_internal_ref_bundles.
- _execute_to_iterator() – returns an iterator, stats, and an optional executor; also sets _current_executor.
- has_serializable_lineage() – indicates whether the dataset lineage can be serialized.
- serialize_lineage() – returns the lineage as bytes.
- Dataset.deserialize_lineage(bytes) – reconstructs a dataset from serialized lineage.
- to_random_access_dataset(key, num_workers=...) – experimental random-access dataset.
- context – the current DataContext.
- set_name(name) / name – assign a human-readable name for metrics.
- get_dataset_id() – unique identifier containing the name, a UUID, and the execution index.
Design Highlights
Data model: Blocks are Arrow or Pandas slices stored as ObjectRef for zero‑copy sharing.
Execution model: Lazy DAG with streaming execution; consumption triggers actual computation.
Extensibility: Datasinks, compute strategies, and expression system allow adding new I/O formats and operators.
Compatibility: Seamless conversion to Pandas, Arrow, PyTorch, TensorFlow, Dask, Modin, Spark, etc.
Resource awareness: Supports num_cpus, num_gpus, memory limits, and Ray remote arguments.
Optimization: Logical plan enables push‑down (e.g., Parquet column pruning) and operator fusion.
Big Data Technology Tribe
Focused on computer science and cutting‑edge tech, we distill complex knowledge into clear, actionable insights. We track tech evolution, share industry trends and deep analysis, helping you keep learning, boost your technical edge, and ride the digital wave forward.