Skip to content

Quick Start

The full pipeline from raw data to model-ready sequences in six steps.

1. Generate Synthetic Data

import polars as pl
from tab2seq.datasets import generate_synthetic_data

# Generate synthetic event tables for 10,000 entities across four registries.
# The result is indexed by registry name; the Source definitions below read
# the files back from "synthetic_data/<registry>.parquet".
data_paths = generate_synthetic_data(
    output_dir="synthetic_data",
    n_entities=10_000,
    seed=742,
    registries=["health", "labour", "survey", "income"],
)

# Peek at the health registry.
health_df = pl.read_parquet(data_paths["health"])
health_df.head()
shape: (5, 7)
┌───────────┬────────────┬───────────┬───────────┬──────────────────┬─────────┬────────────────┐
│ entity_id ┆ date       ┆ diagnosis ┆ procedure ┆ department       ┆ cost    ┆ length_of_stay │
│ str       ┆ date       ┆ str       ┆ str       ┆ str              ┆ f64     ┆ i64            │
╞═══════════╪════════════╪═══════════╪═══════════╪══════════════════╪═════════╪════════════════╡
│ E00001    ┆ 2016-09-15 ┆ J18.1     ┆ CABG      ┆ gastroenterology ┆ 7306.17 ┆ 2              │
│ E00001    ┆ 2017-05-25 ┆ E78.0     ┆ XRAY      ┆ neurology        ┆  138.65 ┆ 1              │
│ E00001    ┆ 2018-01-18 ┆ E78.0     ┆ MRI       ┆ general_surgery  ┆ 6704.59 ┆ 10             │
└───────────┴────────────┴───────────┴───────────┴──────────────────┴─────────┴────────────────┘

2. Define Sources

Each Source describes one event table: its file path, ID column, timestamp, and feature columns.

Define one by one

## Option 1: Source(SourceConfig)
## Example with the health data
from tab2seq.source import (
    Source, SourceConfig,
    CategoricalColConfig, ContinuousColConfig, TemporalColConfig,
)

source_H = Source(config=SourceConfig(
    name="health",
    filepath="synthetic_data/health.parquet",
    id_col="entity_id",
    categorical_cols=[
        CategoricalColConfig(col_name="diagnosis", prefix="DIAG"),
        CategoricalColConfig(col_name="procedure", prefix="PROC"),
        CategoricalColConfig(col_name="department", prefix="DEPT"),
    ],
    continuous_cols=[
        ContinuousColConfig(col_name="cost", prefix="COST", n_bins=20, strategy="quantile"),
        # Same binning as the SourceCollection example below, so both ways of
        # defining the health source produce the same configuration.
        ContinuousColConfig(col_name="length_of_stay", prefix="LOS", n_bins=10, strategy="quantile"),
    ],
    temporal_cols=[
        # "date" is flagged as the primary temporal column for this source.
        TemporalColConfig(col_name="date", is_primary=True, drop_na=True, col_type="datetime")
    ],
    output_format="parquet",
))

## Option 2: SourceConfig -> Source
## Example with the labour data
config_L = SourceConfig(
    name="labour",
    filepath="synthetic_data/labour.parquet",
    id_col="entity_id",
    categorical_cols=[
        CategoricalColConfig(col_name="status", prefix="STATUS"),
        CategoricalColConfig(col_name="occupation", prefix="OCC"),
        CategoricalColConfig(col_name="residence_region", prefix="REGION"),
        # static=True: carried to the cohort split table as an entity-level attribute.
        CategoricalColConfig(col_name="native_language", prefix="LANG", static=True),
    ],
    continuous_cols=[
        # Same binning as the SourceCollection example below, so both ways of
        # defining the labour source produce the same configuration.
        ContinuousColConfig(col_name="weekly_hours", prefix="WEEKLY_HOURS", n_bins=10, strategy="uniform")
    ],
    temporal_cols=[
        TemporalColConfig(col_name="date", is_primary=True, drop_na=True, col_type="datetime"),
        TemporalColConfig(col_name="birthday", is_primary=False, static=True, drop_na=True, col_type="datetime"),
    ],
    output_format="parquet",
)
source_L = Source(config=config_L)

You can then pass these sources to a `Cohort` as a list, e.g. `sources=[source_H, source_L]`.

Define via Source Collection

from tab2seq.source import (
    Source, SourceCollection, SourceConfig,
    CategoricalColConfig, ContinuousColConfig, TemporalColConfig,
)

# Health registry: three categorical columns plus two quantile-binned
# continuous columns, keyed on entity_id and ordered by "date".
health_config = SourceConfig(
    name="health",
    filepath="synthetic_data/health.parquet",
    id_col="entity_id",
    categorical_cols=[
        CategoricalColConfig(col_name="diagnosis", prefix="DIAG"),
        CategoricalColConfig(col_name="procedure", prefix="PROC"),
        CategoricalColConfig(col_name="department", prefix="DEPT"),
    ],
    continuous_cols=[
        ContinuousColConfig(col_name="cost", prefix="COST", n_bins=20, strategy="quantile"),
        ContinuousColConfig(col_name="length_of_stay", prefix="LOS", n_bins=10, strategy="quantile"),
    ],
    temporal_cols=[
        TemporalColConfig(col_name="date", is_primary=True, drop_na=True, col_type="datetime"),
    ],
)

# Labour registry: native_language and birthday are marked static=True,
# i.e. entity-level attributes rather than per-event features.
labour_config = SourceConfig(
    name="labour",
    filepath="synthetic_data/labour.parquet",
    id_col="entity_id",
    categorical_cols=[
        CategoricalColConfig(col_name="status", prefix="STATUS"),
        CategoricalColConfig(col_name="occupation", prefix="OCC"),
        CategoricalColConfig(col_name="residence_region", prefix="REGION"),
        CategoricalColConfig(col_name="native_language", prefix="LANG", static=True),
    ],
    continuous_cols=[
        ContinuousColConfig(col_name="weekly_hours", prefix="WEEKLY_HOURS", n_bins=10, strategy="uniform"),
    ],
    temporal_cols=[
        TemporalColConfig(col_name="date", is_primary=True, drop_na=True, col_type="datetime"),
        TemporalColConfig(col_name="birthday", static=True, drop_na=True, col_type="datetime"),
    ],
)

configs = [health_config, labour_config]

collection = SourceCollection.from_configs(configs)

# Report how many entity IDs each registered source contains.
for src in collection:
    entity_count = len(src.get_entity_ids())
    print(f"{src.name}: {entity_count} entities")

Tip

Columns marked static=True are carried through to the cohort split table as entity-level attributes (e.g. birthday, native language). See Sources for details.

3. Build a Cohort and Splits

from tab2seq.cohort import Cohort, CohortConfig, EntityInclusionCriteria

# Criteria refer to sources by name, so they should only name sources that are
# registered in `collection` ("health" and "labour" here). An "income" table
# was generated in step 1 but is never added to the collection.
criteria = [
    EntityInclusionCriteria(source_name="labour", required=True, min_events=1),
    EntityInclusionCriteria(source_name="health", required=True, min_events=1),
]

cohort = Cohort(
    name="my_cohort",
    sources=collection,
    inclusion_criteria=criteria,
    cache_dir="data/cohorts",
)

# force_recompute=True presumably bypasses any cached entities table under cache_dir.
cohort.build_entities_table(force_recompute=True)

# 70/15/15 train/validation/test split with a fixed seed.
split_cfg = CohortConfig(train_frac=0.7, val_frac=0.15, test_frac=0.15, seed=42)
cohort.build_or_load_splits(split_cfg)
print(f"Cohort size: {len(cohort)} entities")

Only required=True criteria filter entities. Optional criteria are allowed for metadata and validation, but min_events and max_events on them are ignored with a warning.

4. Fit a Vocabulary (Train Split Only)

from tab2seq.tokenization import Tokenizer, Vocabulary, VocabularyConfig

# Vocabulary limits: at most 50,000 entries, a minimum token count of 5,
# and two extra tokens added on top of the fitted ones.
vocab_config = VocabularyConfig(
    max_vocab_size=50_000,
    min_token_count=5,
    extra_tokens=["[DEATH]", "[RETIRED]"],
)
vocab = Vocabulary(config=vocab_config)

# Fit on the cohort's train split only (the split is defined by split_cfg).
vocab_df = vocab.fit_from_cohort_train(cohort=cohort, split_config=split_cfg)
n_tokens = vocab_df.height
print(f"Vocabulary size: {n_tokens}")

5. Build and Persist Tokenized Event Datasets

from tab2seq.datasets import EventDataset, EventDatasetConfig, RelativeDateRule

# Derive an "age_years" feature from labour's static birthday column
# (unit: years, floored to an integer). Static columns appear to be
# namespaced as "<source>__<column>" — confirm against the Sources docs.
age_rule = RelativeDateRule(
    source_static_column="labour__birthday",
    output_column="age_years",
    unit="years",
    floor_int=True,
)

event_config = EventDatasetConfig(
    reference_date="1970-01-01",
    threshold_date="2021-01-01",
    include_after_threshold=True,
    include_token_str=True,
    embed_static_in_events=False,
    relative_date_features=[age_rule],
)

dataset = EventDataset(
    cohort=cohort,
    tokenizer=Tokenizer(vocab),
    dataset_config=event_config,
)

# Persist under a name so the dataset can be reloaded with EventDataset.from_name.
artifacts = dataset.write_parquet(dataset_name="my_dataset_v1", force_write=True)
print(artifacts.dataset_dir)

6. Load and Read Records

# Reload the dataset written in step 5 by its name.
# NOTE: assumes cohort.cache_dir is a pathlib.Path (it was passed as a string
# in step 3) — confirm that Cohort converts it.
datasets_dir = cohort.cache_dir / "datasets"
dataset_loaded = EventDataset.from_name(
    name="my_dataset_v1",
    registry_dir=datasets_dir,
)

# Sample one entity's record from the train split using a fixed seed.
record = dataset_loaded.sample_entity_record(split="train", seed=7)

See Record Formats for all four output formats (raw, frame, tensor, padded_tensor).