EventDataset
EventDataset encodes all events for all cohort entities into Parquet files, partitioned by split. It computes relative-date features, handles static token embedding, and exposes four record access patterns.
Building and persisting
```python
# Assumes cohort, vocab, and Tokenizer are already defined or imported (not shown here).
from tab2seq.datasets import EventDataset, EventDatasetConfig, RelativeDateRule

dataset = EventDataset(
    cohort=cohort,
    tokenizer=Tokenizer(vocab),
    dataset_config=EventDatasetConfig(
        reference_date="1970-01-01",
        threshold_date="2021-01-01",
        include_after_threshold=True,
        relative_date_features=[
            RelativeDateRule(
                source_static_column="labour__birthday",
                output_column="age_years",
                unit="years",
                floor_int=True,
            ),
        ],
    ),
)

artifacts = dataset.write_parquet(
    dataset_name="my_dataset_v1",
    force_write=True,
    include_token_str=True,        # store human-readable token strings (default True)
    embed_static_in_events=False,  # keep static columns in a separate table (default False)
)
print(artifacts.dataset_dir)
```
Key config options
| Option | Description |
|---|---|
| reference_date | Epoch for computing primary_time (days since this date, per event) |
| threshold_date | Date used to flag post-threshold events |
| include_after_threshold | If True, add an after_threshold boolean column to built events |
| relative_date_features | List of RelativeDateRule objects for per-event derived features (e.g. age) |
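The two date options reduce to plain date arithmetic. The sketch below is a minimal illustration under stated assumptions: primary_time and is_after_threshold are hypothetical helpers (not part of tab2seq), events are assumed to carry a calendar date rather than a timestamp, and because the table does not say whether an event falling exactly on threshold_date counts as post-threshold, the strict `>` comparison is a guess.

```python
from datetime import date

# Hypothetical helpers: a sketch of the documented semantics, not tab2seq's API.


def primary_time(event_date: date, reference_date: date) -> int:
    """Whole days from reference_date to event_date (negative if the event is earlier)."""
    return (event_date - reference_date).days


def is_after_threshold(event_date: date, threshold_date: date) -> bool:
    """Assumption: only events strictly later than threshold_date are flagged."""
    return event_date > threshold_date


reference = date.fromisoformat("1970-01-01")
threshold = date.fromisoformat("2021-01-01")
event = date(2021, 3, 15)

print(primary_time(event, reference))        # 18701
print(is_after_threshold(event, threshold))  # True
```

With reference_date="1970-01-01", primary_time is simply the Unix day number of the event date.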
Build-time options
include_token_str and embed_static_in_events are passed directly to build_split, build_all_splits, or write_parquet; they are not stored in the config:
| Parameter | Default | Description |
|---|---|---|
| include_token_str | True | Store human-readable token strings alongside IDs |
| embed_static_in_events | False | Left-join static columns into each event row (denormalised) |
```python
# Build without token strings to reduce size
train_df = dataset.build_split("train", include_token_str=False)

# Embed static columns directly in event rows
train_df = dataset.build_split("train", embed_static_in_events=True)
```
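To make "left-join static columns into each event row" concrete, here is a small stdlib-only sketch on lists of dicts. It is hypothetical: embed_static and the column names (entity_id, sex, birth_year) are invented for illustration, and tab2seq's real tables are Parquet files rather than Python dicts.

```python
from typing import Dict, List

# Hypothetical sketch of what embed_static_in_events=True produces: a left join
# of per-entity static columns onto every event row. Column names are made up.


def embed_static(events: List[dict], static: Dict[str, dict]) -> List[dict]:
    # Every static column seen for any entity; used to fill gaps with None.
    static_cols = sorted({col for cols in static.values() for col in cols})
    joined = []
    for row in events:
        extra = static.get(row["entity_id"], {})
        # Left join: an entity with no static row still keeps its events,
        # with None (null) in each static column. On a name collision the
        # static value wins in this sketch.
        joined.append({**row, **{col: extra.get(col) for col in static_cols}})
    return joined


events = [
    {"entity_id": "E1", "primary_time": 100, "token_ids": [98, 110]},
    {"entity_id": "E1", "primary_time": 120, "token_ids": [97]},
    {"entity_id": "E2", "primary_time": 90, "token_ids": [101]},
]
static = {"E1": {"sex": "F", "birth_year": 1993}}

for row in embed_static(events, static):
    print(row)
```

Note that E1's static values are repeated on both of its events; this duplication is the cost of denormalising, and presumably why the default keeps static columns in a separate table.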
Relative date features
RelativeDateRule computes a per-event derived value by comparing each event's timestamp with a date held in a static column. Unlike TemporalColConfig(static=True), which stores one fixed date per entity, a relative date feature is recomputed for every event:
```python
RelativeDateRule(
    source_static_column="labour__birthday",  # static column from any source
    output_column="age_years",                # name of the column added to each event row
    unit="years",                             # "days", "weeks", "months", or "years"
    floor_int=True,                           # round down (floor) to an integer
)
```
The result (e.g. age_years=28) appears as an extra column in every event row and is stacked into the temporal matrix (column indices 1+) in tensor formats.
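The per-event computation can be sketched as "elapsed days divided by the unit length, optionally floored". This is an illustration under assumptions, not tab2seq's implementation: relative_date_value is a hypothetical helper, and the average unit lengths in DAYS_PER_UNIT (365.25 days per year, a twelfth of that per month) are assumed, since the conversion the library actually uses is not documented here.

```python
import math
from datetime import date

# Assumed average unit lengths in days; tab2seq's exact conversion is not documented here.
DAYS_PER_UNIT = {
    "days": 1.0,
    "weeks": 7.0,
    "months": 365.25 / 12,
    "years": 365.25,
}


def relative_date_value(event_date: date, static_date: date, unit: str = "years", floor_int: bool = True):
    """Hypothetical stand-in for one RelativeDateRule applied to a single event."""
    elapsed_days = (event_date - static_date).days
    value = elapsed_days / DAYS_PER_UNIT[unit]
    # floor_int=True rounds down, so 28.9 years becomes 28.
    return math.floor(value) if floor_int else value


birthday = date(1993, 5, 10)  # value of the static column, e.g. labour__birthday
print(relative_date_value(date(2021, 6, 1), birthday))  # 28
print(relative_date_value(date(2021, 5, 9), birthday))  # 27, one day before the 28th birthday
```

An average-length year can be off by one near some birthdays; a calendar-aware age (comparing month and day) would avoid that, if exact ages matter.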
Loading a saved dataset
```python
dataset_loaded = EventDataset.from_name(
    name="my_dataset_v1",
    registry_dir=cohort.cache_dir / "datasets",
)
```
This reloads the dataset from Parquet without requiring the original cohort, sources, or tokenizer.
Access patterns
Four methods are available on any EventDataset:
```python
# Fetch a specific entity by ID (returns None if it is not in that split)
record = dataset_loaded.get_entity_record("E00003", split="train")

# Random sample
record = dataset_loaded.sample_entity_record(split="train", seed=7)

# Full iterator sweep
for record in dataset_loaded.iter_entity_records(split="train", shuffle=True, seed=42):
    pass  # process each record here

# Stateful, one at a time: remembers its position across calls and returns None when exhausted
record = dataset_loaded.next_entity_record(split="val", shuffle=True, seed=0, reset=True)
while record is not None:
    # process record here
    record = dataset_loaded.next_entity_record(split="val", shuffle=True, seed=0)
```
All four methods accept a format parameter. See Record Formats for details.
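The stateful pattern behaves like a cursor over a fixed, optionally shuffled ordering of entity IDs. The class below is a hypothetical sketch of that behaviour, not tab2seq's code: EntityCursor is an invented name, it models a single split, and the choice to sort IDs before a seeded shuffle (so the same seed always gives the same order) is an assumption.

```python
import random
from typing import Dict, List, Optional


class EntityCursor:
    """Hypothetical sketch of next_entity_record-style iteration for one split:
    one record per call, position remembered between calls, None once exhausted."""

    def __init__(self, records: Dict[str, dict]):
        self._records = records
        self._order: List[str] = []
        self._pos = 0
        self._started = False

    def _restart(self, shuffle: bool, seed: Optional[int]) -> None:
        # Sort first so the same seed always yields the same order.
        self._order = sorted(self._records)
        if shuffle:
            random.Random(seed).shuffle(self._order)
        self._pos = 0
        self._started = True

    def next(self, shuffle: bool = False, seed: Optional[int] = None, reset: bool = False) -> Optional[dict]:
        if reset or not self._started:
            self._restart(shuffle, seed)
        if self._pos >= len(self._order):
            return None  # exhausted until the next reset=True call
        record = self._records[self._order[self._pos]]
        self._pos += 1
        return record


records = {eid: {"entity_id": eid} for eid in ["E1", "E2", "E3", "E4"]}
cursor = EntityCursor(records)

seen = []
record = cursor.next(shuffle=True, seed=0, reset=True)
while record is not None:
    seen.append(record["entity_id"])
    record = cursor.next(shuffle=True, seed=0)
print(seen)
```

Using a private random.Random(seed) instance, rather than the global random module, keeps the ordering reproducible without affecting other code that draws random numbers.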
Special tokens
| Parameter | Default | Description |
|---|---|---|
| include_cls | True | Prepend [CLS] to the static token sequence |
| include_sep | True | Append [SEP] to each event's token sequence |
| static_as_event | False | Embed static tokens as event 0 instead of a separate field |
When static_as_event=False (default), [CLS] is prepended to static_token_ids and to the static token_str. When static_as_event=True, [CLS] becomes the first token of event 0, and [SEP] is appended to event 0 as well. static_token_ids is populated regardless of this flag.
```python
# Static tokens kept separate (default)
record = dataset_loaded.get_entity_record("E1", split="train", include_cls=True, include_sep=True)
# record["static"]["token_ids"] → [2, 105, 86]     ([CLS] + static tokens)
# record["events"][0]["token_ids"] → [98, 110, 3]  (event tokens + [SEP])

# Static tokens embedded as event 0
record = dataset_loaded.get_entity_record("E1", split="train", static_as_event=True)
# record["events"][0]["source_name"] → "__static__"
# record["events"][0]["primary_time"] → 0
# record["static_token_ids"] → still populated (without [SEP])
```
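How the three flags combine can be sketched on raw token-ID lists. This is a hypothetical illustration, not tab2seq's code: assemble_tokens is an invented helper, the IDs [CLS]=2 and [SEP]=3 are inferred from the example output above rather than documented, and what static_token_ids holds when static_as_event=True (here, the bare static tokens) is an assumption, since the text only says it lacks [SEP].

```python
from typing import Dict, List, Sequence

# Assumption: [CLS]=2 and [SEP]=3, inferred from the example IDs above.
CLS_ID, SEP_ID = 2, 3


def assemble_tokens(
    static_tokens: Sequence[int],
    event_tokens: Sequence[Sequence[int]],
    include_cls: bool = True,
    include_sep: bool = True,
    static_as_event: bool = False,
) -> Dict[str, List]:
    """Hypothetical sketch of how the special-token flags shape a record."""
    cls = [CLS_ID] if include_cls else []
    sep = [SEP_ID] if include_sep else []
    # [SEP] closes every regular event.
    events = [list(tokens) + sep for tokens in event_tokens]
    if static_as_event:
        # Static tokens become event 0, opened by [CLS] and closed by [SEP].
        event0 = cls + list(static_tokens) + sep
        # Assumption: the separate static field keeps the bare static tokens here.
        return {"static_token_ids": list(static_tokens), "events": [event0] + events}
    # Default: [CLS] opens the separate static sequence.
    return {"static_token_ids": cls + list(static_tokens), "events": events}


print(assemble_tokens([105, 86], [[98, 110]]))
# {'static_token_ids': [2, 105, 86], 'events': [[98, 110, 3]]}
print(assemble_tokens([105, 86], [[98, 110]], static_as_event=True))
```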
Filtering and truncating events
All access methods support filtering and sequence length control:
| Parameter | Default | Description |
|---|---|---|
| include_after_threshold | False | When False, exclude events where after_threshold=True |
| censoring | "none" | "right" keeps the earliest events; "left" keeps the latest events |
| max_events | None | Maximum number of events per entity |
| max_tokens | None | Maximum total token count; only whole events are kept, never partial ones |
max_events and max_tokens cannot be set together. When an entity's sequence is already within the limit, censoring has no effect.
```python
# Keep only the 50 most recent events
for record in dataset.iter_entity_records("train", censoring="left", max_events=50):
    pass

# Limit to 512 tokens in total, dropping events from the end
for record in dataset.iter_entity_records("train", censoring="right", max_tokens=512):
    pass

# Include post-threshold events (requires include_after_threshold=True in the config,
# which adds the after_threshold column at build time)
record = dataset.get_entity_record("E1", split="train", include_after_threshold=True)
```
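The filtering and censoring rules can be sketched as a single function over a chronologically sorted event list. This is a hypothetical illustration, not tab2seq's implementation: select_events is an invented name, and three behaviours are assumptions rather than documented facts: threshold filtering runs before truncation, max_tokens stops at the first whole event that would overflow (instead of skipping it and trying later ones), and censoring="none" applies no truncation even if a limit is set.

```python
from typing import List, Optional


def select_events(
    events: List[dict],
    include_after_threshold: bool = False,
    censoring: str = "none",
    max_events: Optional[int] = None,
    max_tokens: Optional[int] = None,
) -> List[dict]:
    """Hypothetical sketch of the filtering and censoring rules above.
    Assumes events are already sorted chronologically."""
    if max_events is not None and max_tokens is not None:
        raise ValueError("max_events and max_tokens cannot be set together")
    if not include_after_threshold:
        events = [e for e in events if not e.get("after_threshold", False)]
    if censoring == "none" or (max_events is None and max_tokens is None):
        return events  # assumption: no truncation without a censoring side
    if censoring not in ("left", "right"):
        raise ValueError(f"unknown censoring mode: {censoring!r}")

    # "right" keeps the earliest events (walk forward);
    # "left" keeps the latest events (walk backward).
    ordered = events if censoring == "right" else events[::-1]
    kept: List[dict] = []
    used_tokens = 0
    for event in ordered:
        if max_events is not None and len(kept) >= max_events:
            break
        n_tokens = len(event["token_ids"])
        if max_tokens is not None and used_tokens + n_tokens > max_tokens:
            break  # whole events only: stop rather than cut this event
        kept.append(event)
        used_tokens += n_tokens
    # Restore chronological order for left censoring.
    return kept if censoring == "right" else kept[::-1]


events = [
    {"primary_time": t, "token_ids": [t] * 3, "after_threshold": t >= 4}
    for t in range(5)
]
latest_two = select_events(events, censoring="left", max_events=2)
print([e["primary_time"] for e in latest_two])  # [2, 3]
```

Here the post-threshold event (primary_time 4) is dropped first, so "the latest two" are chosen from the remaining four events.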