A single data source with schema validation, chunked access, and preprocessing.
Uses Polars internally for lazy evaluation and efficient I/O.
Parameters:
| Name |
Type |
Description |
Default |
config
|
SourceConfig
|
Configuration describing this source.
|
required
|
Note
Rows with null entity IDs or primary temporal columns are automatically
dropped during scanning and reading. Entity IDs are cast to string
regardless of the original dtype.
Example::
source = Source(
config=SourceConfig(
name="health",
filepath=Path("data/health.parquet"),
id_col="entity_id",
temporal_cols=[
TemporalColConfig(
col_name="date", is_primary=True, drop_na=True, col_type="datetime"
)
],
categorical_cols=[
CategoricalColConfig(col_name="diagnosis", prefix="DIAG"),
CategoricalColConfig(col_name="department", prefix="DEPT"),
],
continuous_cols=[
ContinuousColConfig(
col_name="cost",
prefix="COST",
n_bins=20,
strategy="quantile",
)
],
output_format="parquet",
)
)
# lazy scan
lf = source.scan()
# eager read
df = source.read_all()
# chunked iteration
for chunk in source.iter_chunks(chunk_size=50_000):
process(chunk)
# preprocess and cache to disk
df = source.process(cache=True)
Source code in tab2seq/source/core.py
| class Source:
"""A single data source with schema validation, chunked access, and preprocessing.
Uses Polars internally for lazy evaluation and efficient I/O.
Args:
config: Configuration describing this source.
Note:
Rows with null entity IDs or primary temporal columns are automatically
dropped during scanning and reading. Entity IDs are cast to string
regardless of the original dtype.
Example::
source = Source(
config=SourceConfig(
name="health",
filepath=Path("data/health.parquet"),
id_col="entity_id",
temporal_cols=[
TemporalColConfig(
col_name="date", is_primary=True, drop_na=True, col_type="datetime"
)
],
categorical_cols=[
CategoricalColConfig(col_name="diagnosis", prefix="DIAG"),
CategoricalColConfig(col_name="department", prefix="DEPT"),
],
continuous_cols=[
ContinuousColConfig(
col_name="cost",
prefix="COST",
n_bins=20,
strategy="quantile",
)
],
output_format="parquet",
)
)
# lazy scan
lf = source.scan()
# eager read
df = source.read_all()
# chunked iteration
for chunk in source.iter_chunks(chunk_size=50_000):
process(chunk)
# preprocess and cache to disk
df = source.process(cache=True)
"""
def __init__(self, config: SourceConfig) -> None:
"""Initialize Source with configuration.
Args:
config: SourceConfig object describing this source.
"""
if not isinstance(config, SourceConfig):
raise TypeError(f"Expected SourceConfig, got {type(config).__name__}")
self.config = config
def __repr__(self) -> str:
"""Return string representation of Source."""
return f"Source(name={self.config.name!r}, path={self.config.filepath!r})"
@property
def name(self) -> str:
"""Shortcut to ``self.config.name``."""
return self.config.name
@property
def columns(self) -> list[str]:
"""All columns to select from the source file."""
return self.config.cols
# ------------------------------------------------------------------
# Validation
# ------------------------------------------------------------------
def validate_schema(self, df: pl.DataFrame | pl.LazyFrame) -> None:
"""Check that expected columns are present.
Args:
df: A DataFrame or LazyFrame from this source.
Raises:
SchemaError: If any required column is missing.
"""
columns = (
df.collect_schema().names() if isinstance(df, pl.LazyFrame) else df.columns
)
available = set(columns)
expected = set(self.columns)
missing = expected - available
if missing:
msg = f"Source '{self.name}' is missing columns: {missing}"
raise SchemaError(msg)
# ------------------------------------------------------------------
# Reading
# ------------------------------------------------------------------
def scan(self) -> pl.LazyFrame:
"""Return a lazy scan of the source file.
This is the most idiomatic Polars entry point. No data is
read until ``.collect()`` is called.
Note:
Rows with null entity IDs are automatically dropped. For sequence columns
columns with ``drop_na=True``, rows with null sequence values are also dropped.
Entity IDs are cast to string regardless of the original dtype.
Returns:
LazyFrame selecting only the relevant columns, sorted by entity ID
and timestamp/temporal (if defined).
"""
lf = self._scan_raw()
self.validate_schema(lf)
ts_col_names = [col.col_name for col in (self.config.temporal_cols or [])]
drop_na_cols = [
col.col_name for col in (self.config.temporal_cols or []) if col.drop_na
]
lf = (
lf.select(self.columns)
.cast({self.config.id_col: pl.Utf8})
.drop_nulls(subset=[self.config.id_col, *drop_na_cols])
)
sort_by = [self.config.id_col, *ts_col_names]
return lf.sort(sort_by)
def read_all(self) -> pl.DataFrame:
"""Eagerly read the entire source into a DataFrame.
Prefer :meth:`scan` or :meth:`iter_chunks` for large files.
Returns:
Full DataFrame for this source.
"""
return self.scan().collect()
def iter_chunks(self, chunk_size: int = 100_000) -> Iterator[pl.DataFrame]:
"""Yield DataFrames of approximately ``chunk_size`` rows.
Args:
chunk_size: Maximum number of rows per chunk.
Yields:
Polars DataFrames of up to ``chunk_size`` rows.
"""
df = self.read_all()
n_rows = df.height
for offset in range(0, n_rows, chunk_size):
yield df.slice(offset, chunk_size)
def get_entity_ids(self) -> set[str]:
"""Collect all unique entity IDs from this source.
Note:
Rows with null entity IDs or timestamps/temporal values (where ``drop_na=True``) are
excluded. Entity IDs are cast to string regardless of the original dtype.
Returns:
Set of unique entity identifiers.
"""
ids = (
self.scan()
.select(self.config.id_col)
.unique()
.collect()
.get_column(self.config.id_col)
.to_list()
)
logger.info("Source '%s': found %d unique entities", self.name, len(ids))
return set(ids)
# ------------------------------------------------------------------
# Private helpers
# ------------------------------------------------------------------
def _scan_raw(self) -> pl.LazyFrame:
"""Return a raw lazy scan without column selection or sorting."""
filepath = Path(self.config.filepath)
if filepath.suffix == ".parquet":
return pl.scan_parquet(filepath)
if filepath.suffix == ".csv":
return pl.scan_csv(filepath, has_header=True)
msg = f"Unsupported file format: {filepath.suffix}"
raise ValueError(msg)
def _apply_preprocessing(self) -> pl.DataFrame:
"""Read source data and apply all preprocessing transformations."""
df = self.read_all()
# Drop nulls for configured columns (categorical and continuous)
drop_na_cols = [
col.col_name
for col in (self.config.categorical_cols or [])
+ (self.config.continuous_cols or [])
if col.drop_na
]
if drop_na_cols:
df = df.drop_nulls(subset=drop_na_cols)
# Cast categorical columns to string and prepend prefix (e.g. "I21.9" -> "DIAG_I21.9")
for col_cfg in self.config.categorical_cols or []:
df = df.with_columns(
(
pl.lit(col_cfg.prefix + "_")
+ pl.col(col_cfg.col_name).cast(pl.Utf8)
).alias(col_cfg.col_name)
)
# df = df.rename({self.config.id_col: "entity_id"})
return df
def _config_hash(self) -> str:
"""Return a short hash of the configuration for cache key generation."""
config_json = self.config.model_dump_json()
return hashlib.sha256(config_json.encode()).hexdigest()[:16]
|
name
property
Shortcut to self.config.name.
columns
property
All columns to select from the source file.
__init__
__init__(config: SourceConfig) -> None
Initialize Source with configuration.
Parameters:
| Name |
Type |
Description |
Default |
config
|
SourceConfig
|
SourceConfig object describing this source.
|
required
|
Source code in tab2seq/source/core.py
| def __init__(self, config: SourceConfig) -> None:
"""Initialize Source with configuration.
Args:
config: SourceConfig object describing this source.
"""
if not isinstance(config, SourceConfig):
raise TypeError(f"Expected SourceConfig, got {type(config).__name__}")
self.config = config
|
__repr__
Return string representation of Source.
Source code in tab2seq/source/core.py
| def __repr__(self) -> str:
"""Return string representation of Source."""
return f"Source(name={self.config.name!r}, path={self.config.filepath!r})"
|
validate_schema
validate_schema(df: DataFrame | LazyFrame) -> None
Check that expected columns are present.
Parameters:
| Name |
Type |
Description |
Default |
df
|
DataFrame | LazyFrame
|
A DataFrame or LazyFrame from this source.
|
required
|
Raises:
| Type |
Description |
SchemaError
|
If any required column is missing.
|
Source code in tab2seq/source/core.py
| def validate_schema(self, df: pl.DataFrame | pl.LazyFrame) -> None:
"""Check that expected columns are present.
Args:
df: A DataFrame or LazyFrame from this source.
Raises:
SchemaError: If any required column is missing.
"""
columns = (
df.collect_schema().names() if isinstance(df, pl.LazyFrame) else df.columns
)
available = set(columns)
expected = set(self.columns)
missing = expected - available
if missing:
msg = f"Source '{self.name}' is missing columns: {missing}"
raise SchemaError(msg)
|
scan
Return a lazy scan of the source file.
This is the most idiomatic Polars entry point. No data is
read until .collect() is called.
Note
Rows with null entity IDs are automatically dropped. For sequence columns
columns with drop_na=True, rows with null sequence values are also dropped.
Entity IDs are cast to string regardless of the original dtype.
Returns:
| Type |
Description |
LazyFrame
|
LazyFrame selecting only the relevant columns, sorted by entity ID
|
LazyFrame
|
and timestamp/temporal (if defined).
|
Source code in tab2seq/source/core.py
| def scan(self) -> pl.LazyFrame:
"""Return a lazy scan of the source file.
This is the most idiomatic Polars entry point. No data is
read until ``.collect()`` is called.
Note:
Rows with null entity IDs are automatically dropped. For sequence columns
columns with ``drop_na=True``, rows with null sequence values are also dropped.
Entity IDs are cast to string regardless of the original dtype.
Returns:
LazyFrame selecting only the relevant columns, sorted by entity ID
and timestamp/temporal (if defined).
"""
lf = self._scan_raw()
self.validate_schema(lf)
ts_col_names = [col.col_name for col in (self.config.temporal_cols or [])]
drop_na_cols = [
col.col_name for col in (self.config.temporal_cols or []) if col.drop_na
]
lf = (
lf.select(self.columns)
.cast({self.config.id_col: pl.Utf8})
.drop_nulls(subset=[self.config.id_col, *drop_na_cols])
)
sort_by = [self.config.id_col, *ts_col_names]
return lf.sort(sort_by)
|
read_all
read_all() -> pl.DataFrame
Eagerly read the entire source into a DataFrame.
Prefer :meth:scan or :meth:iter_chunks for large files.
Returns:
| Type |
Description |
DataFrame
|
Full DataFrame for this source.
|
Source code in tab2seq/source/core.py
| def read_all(self) -> pl.DataFrame:
"""Eagerly read the entire source into a DataFrame.
Prefer :meth:`scan` or :meth:`iter_chunks` for large files.
Returns:
Full DataFrame for this source.
"""
return self.scan().collect()
|
iter_chunks
iter_chunks(chunk_size: int = 100000) -> Iterator[pl.DataFrame]
Yield DataFrames of approximately chunk_size rows.
Parameters:
| Name |
Type |
Description |
Default |
chunk_size
|
int
|
Maximum number of rows per chunk.
|
100000
|
Yields:
| Type |
Description |
DataFrame
|
Polars DataFrames of up to chunk_size rows.
|
Source code in tab2seq/source/core.py
| def iter_chunks(self, chunk_size: int = 100_000) -> Iterator[pl.DataFrame]:
"""Yield DataFrames of approximately ``chunk_size`` rows.
Args:
chunk_size: Maximum number of rows per chunk.
Yields:
Polars DataFrames of up to ``chunk_size`` rows.
"""
df = self.read_all()
n_rows = df.height
for offset in range(0, n_rows, chunk_size):
yield df.slice(offset, chunk_size)
|
get_entity_ids
get_entity_ids() -> set[str]
Collect all unique entity IDs from this source.
Note
Rows with null entity IDs or timestamps/temporal values (where drop_na=True) are
excluded. Entity IDs are cast to string regardless of the original dtype.
Returns:
| Type |
Description |
set[str]
|
Set of unique entity identifiers.
|
Source code in tab2seq/source/core.py
| def get_entity_ids(self) -> set[str]:
"""Collect all unique entity IDs from this source.
Note:
Rows with null entity IDs or timestamps/temporal values (where ``drop_na=True``) are
excluded. Entity IDs are cast to string regardless of the original dtype.
Returns:
Set of unique entity identifiers.
"""
ids = (
self.scan()
.select(self.config.id_col)
.unique()
.collect()
.get_column(self.config.id_col)
.to_list()
)
logger.info("Source '%s': found %d unique entities", self.name, len(ids))
return set(ids)
|