Skip to content

tab2seq.cohort

tab2seq.cohort

Cohort construction, filtering, and split caching.

CohortConfig

Bases: BaseModel

Configuration for defining a cohort of entities and splitting into train/val/test sets.

Attributes:

Name Type Description
use_splits bool

Whether to split the cohort into train/val/test sets.

train_frac float

Fraction of entities to include in the training set.

val_frac float

Fraction of entities to include in the validation set.

test_frac float

Fraction of entities to include in the test set.

seed int

Random seed for reproducible splits.

stratify_col str | None

Optional column name to use for stratified splitting. Must be a static column present in the dataset.

Source code in tab2seq/cohort/config.py
class CohortConfig(BaseModel):
    """Settings that control how a cohort is partitioned into train/val/test.

    Attributes:
        use_splits: Whether to split the cohort into train/val/test sets.
        train_frac: Fraction of entities assigned to the training set.
        val_frac: Fraction of entities assigned to the validation set.
        test_frac: Fraction of entities assigned to the test set.
        seed: Random seed for reproducible splits.
        stratify_col: Optional column name to use for stratified splitting.
                      Must be a static column present in the dataset.
    """

    use_splits: bool = True
    train_frac: float = Field(0.7, ge=0.0, le=1.0)
    val_frac: float = Field(0.15, ge=0.0, le=1.0)
    test_frac: float = Field(0.15, ge=0.0, le=1.0)
    seed: int = 792
    stratify_col: str | None = None

    @field_validator("stratify_col", mode="before")
    @classmethod
    def _validate_stratify_col(cls, v: str | None) -> str | None:
        """Accept ``None`` or a non-blank string without surrounding whitespace."""
        if v is None:
            return None
        # Non-strings are treated like blank strings: both are rejected below.
        trimmed = v.strip() if isinstance(v, str) else ""
        if not trimmed:
            raise ValueError("'stratify_col' must be a non-empty string when set.")
        if trimmed != v:
            raise ValueError("'stratify_col' cannot have surrounding whitespace.")
        return v

    @model_validator(mode="after")
    def _fractions_sum_to_one(self) -> CohortConfig:
        """Require the three fractions to add up to 1.0 when splitting is on."""
        if not self.use_splits:
            # Fractions are not checked when no split is requested.
            return self
        total = self.train_frac + self.val_frac + self.test_frac
        deviation = abs(total - 1.0)
        if deviation > 1e-12:
            msg = f"Split fractions must sum to 1.0, got {total:.4f}"
            raise ValueError(msg)
        return self

    def config_hash(self) -> str:
        """Return a 16-hex-character SHA-256 digest of the full configuration.

        The model is dumped (``None`` values included), serialized as JSON with
        sorted keys, and hashed, so equal configurations yield equal hashes.
        """
        dumped = self.model_dump(exclude_none=False)
        canonical = json.dumps(dumped, sort_keys=True)
        digest = hashlib.sha256(canonical.encode())
        return digest.hexdigest()[:16]

config_hash

config_hash() -> str

Deterministic hash of split configuration.

Source code in tab2seq/cohort/config.py
def config_hash(self) -> str:
    """Return a 16-hex-character SHA-256 digest of the full configuration.

    The model is dumped (``None`` values included), serialized as JSON with
    sorted keys, and hashed, so equal configurations yield equal hashes.
    """
    dumped = self.model_dump(exclude_none=False)
    canonical = json.dumps(dumped, sort_keys=True)
    digest = hashlib.sha256(canonical.encode())
    return digest.hexdigest()[:16]

EntityInclusionCriteria

Bases: BaseModel

Entity inclusion criteria for a single Source.

Defines requirements an entity must meet within a specific Source to be included in the cohort.

Attributes:

Name Type Description
source_name str

Name of the Source this criteria applies to.

required bool

If True, entities must appear in this Source.

min_events int | None

Minimum number of events an entity must have in this Source. Only checked when required is True.

max_events int | None

Maximum number of events an entity may have in this Source. Only checked when required is True.

Source code in tab2seq/cohort/config.py
class EntityInclusionCriteria(BaseModel):
    """Per-`Source` requirements an entity must satisfy to join the cohort.

    Attributes:
        source_name: Name of the `Source` this criteria applies to.
        required: If ``True``, entities must appear in this `Source`.
        min_events: Minimum number of events an entity must have
            in this `Source`. Only checked when ``required`` is ``True``.
        max_events: Maximum number of events an entity may have
            in this `Source`. Only checked when ``required`` is ``True``.

    """

    source_name: str
    required: bool = False
    min_events: int | None = None
    max_events: int | None = None

    @field_validator("source_name", mode="before")
    @classmethod
    def _no_whitespace_string(cls, v: str, info: Any) -> str:
        """Delegate ``source_name`` validation to the shared string validator."""
        return validate_no_whitespace_string(v, info)

    @model_validator(mode="after")
    def _validate_event_bounds(self) -> EntityInclusionCriteria:
        """Check event bounds when required; warn if bounds are set but unused."""
        if self.required:
            lower = self.min_events
            if lower is None or lower < 1:
                raise ValueError(
                    "If 'required' is True, 'min_events' must be a positive integer."
                )
            upper = self.max_events
            if upper is not None and upper < lower:
                raise ValueError(
                    f"'max_events' ({self.max_events}) cannot be less than 'min_events' ({self.min_events})."
                )
            return self

        # Not required: bounds have no effect, so tell the user if any were set.
        no_bounds_given = self.min_events is None and self.max_events is None
        if not no_bounds_given:
            warnings.warn(
                (
                    f"EntityInclusionCriteria for source '{self.source_name}' has "
                    "required=False, so min_events and max_events are ignored."
                ),
                UserWarning,
                stacklevel=2,
            )
        return self

Cohort

Cohort unites the Source objects to create a unified entity set and filter, split, and cache it for modeling.

Source code in tab2seq/cohort/core.py
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
class Cohort:
    """Cohort unites the `Source` objects to create a unified entity set and filter, split, and cache it for modeling.
    """

    def __init__(
        self,
        name: str,
        sources: Source | list[Source] | SourceCollection,
        inclusion_criteria: list[EntityInclusionCriteria] | None = None,
        cache_dir: str | Path = Path("data/cohorts/"),
        use_cache: bool = True,
    ) -> None:
        """Initialize a Cohort with given sources and configuration.
        Args:
            name: Unique name for this cohort, used in caching and metadata.
            sources: One or more `Source` objects or a `SourceCollection` defining the data sources for this cohort.
            inclusion_criteria: Optional list of `EntityInclusionCriteria` to filter entities.
            cache_dir: Base directory for caching cohort artifacts. If None, caching is disabled.
            use_cache: Whether to enable caching for this cohort. If False, no caching will occur
                even if `cache_dir` is provided.
        Raises:
            ValueError: If `name` is empty or contains only whitespace, or if `inclusion_criteria` contains invalid entries.
            KeyError: If `inclusion_criteria` references a source not in the collection.
            TypeError: If `sources` is not a valid type (Source, list[Source], or SourceCollection).
        """
        if not isinstance(name, str) or not name.strip():
            raise ValueError("'name' must be a non-empty string.")
        if name != name.strip():
            raise ValueError("'name' cannot have leading or trailing whitespace.")

        self._name = name

        if isinstance(sources, Source):
            self._collection = SourceCollection([sources])
        elif isinstance(sources, list):
            self._collection = SourceCollection(sources)
        elif isinstance(sources, SourceCollection):
            self._collection = sources
        else:
            raise TypeError(
                "'sources' must be a Source, list[Source], or SourceCollection."
            )

        self._criteria = inclusion_criteria or []

        base_cache_dir = Path(cache_dir) if cache_dir else None
        self._cache_dir = base_cache_dir / self._name if base_cache_dir else None
        if self._cache_dir and use_cache:
            self._cache_dir.mkdir(parents=True, exist_ok=True)
            self._use_cache = True
        else:
            logger.warning(
                "Caching is disabled for this cohort. To enable caching, provide a valid 'cache_dir' and set 'use_cache=True'."
            )
            self._use_cache = False
            self._cache_dir = None

        self._entity_ids: set[str] = self._resolve_entity_ids()
        self._entities_table: pl.DataFrame | None = None

    # ------------------------------------------------------------------
    # Properties
    # ------------------------------------------------------------------

    @property
    def name(self) -> str:
        """Return the cohort name that was validated and stored at construction."""
        return self._name

    @property
    def entity_ids(self) -> set[str]:
        """Set of entity IDs in this cohort."""
        return set(self._entity_ids)

    @property
    def entity_id_list(self) -> list[str]:
        """Deterministically sorted entity IDs in this cohort."""
        return sorted(self._entity_ids)

    @property
    def cache_dir(self) -> Path | None:
        """Return this cohort's cache directory, or ``None`` when caching is disabled.

        When set, it is the base ``cache_dir`` given to the constructor joined
        with the cohort name.
        """
        return self._cache_dir

    @property
    def use_cache(self) -> bool:
        """Return whether caching is active for this cohort.

        This is ``True`` only when the constructor received both a usable
        ``cache_dir`` and ``use_cache=True``.
        """
        return self._use_cache

    @property
    def criteria(self) -> list[EntityInclusionCriteria]:
        """Return a shallow copy of the `EntityInclusionCriteria` list for this cohort.

        The result is an empty list (never ``None``) when no criteria were
        supplied at construction.
        """
        return list(self._criteria)

    @property
    def entities_table(self) -> pl.DataFrame | None:
        """Return the entities table last built or loaded by ``build_entities_table``.

        The value is ``None`` until ``build_entities_table`` has been called.
        """
        return self._entities_table

    @property
    def collection(self) -> SourceCollection:
        """Return the `SourceCollection` backing this cohort.

        This is either the collection passed to the constructor or one that
        was created there to wrap the given source(s).
        """
        return self._collection

    # ------------------------------------------------------------------
    # Dunder methods
    # ------------------------------------------------------------------

    def __len__(self) -> int:
        """Return the number of entity IDs resolved for this cohort."""
        return len(self._entity_ids)

    def __contains__(self, entity_id: str) -> bool:
        """Return ``True`` if ``entity_id`` is one of the cohort's resolved entity IDs."""
        return entity_id in self._entity_ids

    def __repr__(self) -> str:
        return (
            f"Cohort("
            f"name={self._name}, "
            f"sources={self._collection.names}, "
            f"n_entities={len(self)}, "
            f"cache_dir={self._cache_dir}"
            f")"
        )

    # ------------------------------------------------------------------
    # Entity resolution
    # ------------------------------------------------------------------

    def _resolve_entity_ids(self) -> set[str]:
        """Apply inclusion criteria and return the surviving entity ID set.

        Starts from the union of all entity IDs across the collection, then
        applies each ``EntityInclusionCriteria`` in sequence.  Only criteria
        with ``required=True`` have a filtering effect.  Processing stops
        early once no candidates remain.

        Returns:
            Set of entity IDs satisfying all inclusion criteria.

        Raises:
            KeyError: If a criteria references a source not in the collection.
            SchemaError: If a source fails while being scanned; the error is
                re-raised with the cohort (and source) name added.
        """
        try:
            # Copy: the set is narrowed in place below and must not alias
            # whatever object the collection handed back.
            candidates = set(self._collection.get_all_entity_ids())
        except SchemaError as exc:
            raise SchemaError(
                f"Failed to resolve entity IDs for cohort '{self.name}' because "
                f"a source schema is invalid: {exc}"
            ) from exc

        logger.info(
            "Candidate pool: %d entities (union across all sources)", len(candidates)
        )

        if not self._criteria:
            logger.info("No inclusion criteria provided; keeping all candidates.")
            return candidates

        for criteria in self._criteria:
            if criteria.source_name not in self._collection:
                raise KeyError(
                    f"Inclusion criteria references unknown source "
                    f"'{criteria.source_name}'. Available: {self._collection.names}"
                )

            if not criteria.required:
                continue

            source = self._collection[criteria.source_name]
            id_col = source.config.id_col
            before = len(candidates)

            try:
                event_counts = (
                    source.scan()
                    .group_by(id_col)
                    .agg(pl.len().alias("_n_events"))
                    .filter(pl.col("_n_events") >= criteria.min_events)
                )
                if criteria.max_events is not None:
                    event_counts = event_counts.filter(
                        pl.col("_n_events") <= criteria.max_events
                    )
                # Cast to string so IDs compare equal to the string candidate
                # IDs, mirroring the cast done in filter_df / filter_source.
                qualifying = (
                    event_counts.select(pl.col(id_col).cast(pl.Utf8))
                    .collect()
                    .get_column(id_col)
                    .to_list()
                )
            except SchemaError as exc:
                raise SchemaError(
                    f"Failed to apply inclusion criteria for cohort '{self.name}' on "
                    f"source '{source.name}': {exc}"
                ) from exc

            candidates &= set(qualifying)

            logger.info(
                "After criteria for '%s' (min=%s, max=%s): %d -> %d entities",
                criteria.source_name,
                criteria.min_events,
                criteria.max_events,
                before,
                len(candidates),
            )

            if not candidates:
                logger.warning(
                    "Cohort '%s' resolved to 0 entities after required inclusion "
                    "criteria on source '%s' (min_events=%s, max_events=%s).",
                    self.name,
                    criteria.source_name,
                    criteria.min_events,
                    criteria.max_events,
                )
                break

        logger.info("Resolved cohort: %d entities", len(candidates))
        return candidates

    # ------------------------------------------------------------------
    # Public filtering helpers
    # ------------------------------------------------------------------

    def filter_df(self, df: pl.DataFrame, entity_id_col: str = "entity_id") -> pl.DataFrame:
        """Return the rows of ``df`` whose ID column, cast to string, is a cohort entity ID.

        Args:
            df: Frame to filter.
            entity_id_col: Name of the column in ``df`` holding entity IDs.
        """
        id_as_text = pl.col(entity_id_col).cast(pl.Utf8)
        in_cohort = id_as_text.is_in(self.entity_id_list)
        return df.filter(in_cohort)

    def filter_source(self, source: Source) -> pl.LazyFrame:
        """Return a lazy scan of ``source`` restricted to cohort entities.

        The source's configured ID column is cast to string before being
        matched against the cohort's entity IDs.
        """
        lazy = source.scan()
        id_as_text = pl.col(source.config.id_col).cast(pl.Utf8)
        return lazy.filter(id_as_text.is_in(self.entity_id_list))

    # ------------------------------------------------------------------
    # Entities table
    # ------------------------------------------------------------------

    def build_entities_table(self, force_recompute: bool = False) -> pl.DataFrame:
        """Build or load entity_id plus static properties for this cohort.

        A cached table is reused only when its metadata still matches the
        current cohort definition (entity count, criteria hash and per-source
        config hashes); otherwise it is rebuilt and the cache is overwritten.

        Args:
            force_recompute: If True, ignore any cached table and rebuild.

        Returns:
            One row per cohort entity, sorted by ``entity_id``, with each
            source's static columns left-joined in under ``<source>__<col>``
            names.
        """
        entities_path = self._entities_table_path()
        metadata_path = self._entities_metadata_path()
        can_cache = (
            self.use_cache and entities_path is not None and metadata_path is not None
        )

        if can_cache and not force_recompute and entities_path.exists():
            if self._entities_cache_is_current(metadata_path):
                logger.info("Loading entities table from cache: %s", entities_path)
                self._entities_table = pl.read_parquet(entities_path)
                return self._entities_table
            logger.warning(
                "Cached entities table at %s does not match the current cohort "
                "definition; rebuilding.",
                entities_path,
            )

        entity_df = pl.DataFrame(
            {"entity_id": pl.Series(self.entity_id_list, dtype=pl.Utf8)}
        )

        for source in self._collection:
            static_slice = self._build_source_static_slice(source)
            if static_slice is not None:
                entity_df = entity_df.join(static_slice, on="entity_id", how="left")

        entity_df = entity_df.sort("entity_id")
        self._entities_table = entity_df

        if can_cache:
            entities_path.parent.mkdir(parents=True, exist_ok=True)
            entity_df.write_parquet(entities_path)
            metadata = {
                "cohort_name": self.name,
                "n_entities": entity_df.height,
                "generated_at_utc": datetime.now(timezone.utc).isoformat(),
                "criteria_hash": self._criteria_hash(),
                "source_config_hashes": self._source_config_hashes(),
                "columns": entity_df.columns,
            }
            metadata_path.write_text(json.dumps(metadata, indent=2), encoding="utf-8")

        return entity_df

    def _source_config_hashes(self) -> dict[str, str]:
        """Map each source name to the stable hash of its JSON-dumped config."""
        return {
            source.name: self._stable_hash(source.config.model_dump_json())
            for source in self._collection
        }

    def _entities_cache_is_current(self, metadata_path: Path) -> bool:
        """Return True if cached entities metadata matches this cohort's definition."""
        try:
            cached = json.loads(metadata_path.read_text(encoding="utf-8"))
        except (OSError, ValueError):
            # Missing or unreadable metadata: the cache cannot be trusted.
            return False
        if not isinstance(cached, dict):
            return False
        return (
            cached.get("n_entities") == len(self._entity_ids)
            and cached.get("criteria_hash") == self._criteria_hash()
            and cached.get("source_config_hashes") == self._source_config_hashes()
        )

    # ------------------------------------------------------------------
    # Split building
    # ------------------------------------------------------------------

    def build_or_load_splits(
        self,
        split_config: CohortConfig | None = None,
        force_recompute: bool = False,
    ) -> pl.DataFrame:
        """Build or load cohort splits with full static context per entity.

        A cached split table is reused only if it covers exactly the cohort's
        current entity IDs; a cache written for a different entity set is
        discarded and both the entities table and the split are rebuilt.

        Args:
            split_config: Split settings; defaults to a 0.7/0.15/0.15 split.
            force_recompute: If True, ignore cached artifacts and rebuild.

        Returns:
            The entities table with an added string ``split`` column.
        """
        cfg = split_config or CohortConfig(
            train_frac=0.7,
            val_frac=0.15,
            test_frac=0.15,
        )
        split_path = self._split_table_path(cfg)
        metadata_path = self._split_metadata_path(cfg)
        can_cache = (
            self.use_cache and split_path is not None and metadata_path is not None
        )

        stale_cache = False
        if can_cache and not force_recompute and split_path.exists():
            cached = pl.read_parquet(split_path)
            if (
                "entity_id" in cached.columns
                and "split" in cached.columns
                and set(cached.get_column("entity_id").to_list()) == self._entity_ids
            ):
                logger.info("Loading split table from cache: %s", split_path)
                return cached
            logger.warning(
                "Cached split table at %s does not match the cohort's current "
                "entities; rebuilding.",
                split_path,
            )
            # The split cache is keyed only by the split config, so a mismatch
            # means the cohort changed; any cached entities table is suspect too.
            stale_cache = True

        entities = self.build_entities_table(
            force_recompute=force_recompute or stale_cache
        )
        split_labels = self._assign_splits(entities, cfg)
        # Explicit dtype keeps the column a string column even for an empty cohort.
        split_df = entities.with_columns(
            pl.Series(name="split", values=split_labels, dtype=pl.Utf8)
        )

        if can_cache:
            split_path.parent.mkdir(parents=True, exist_ok=True)
            split_df.write_parquet(split_path)
            # Sorted so the metadata file is identical across identical runs.
            split_counts = (
                split_df.group_by("split")
                .len()
                .rename({"len": "count"})
                .sort("split")
                .to_dicts()
            )
            metadata = {
                "cohort_name": self.name,
                "split_config": cfg.model_dump(exclude_none=False),
                "split_config_hash": cfg.config_hash(),
                "n_entities": split_df.height,
                "split_counts": split_counts,
                "generated_at_utc": datetime.now(timezone.utc).isoformat(),
            }
            metadata_path.write_text(json.dumps(metadata, indent=2), encoding="utf-8")

        return split_df


    def load_split_config(self, split_hash: str) -> CohortConfig:
        """Reconstruct the CohortConfig used to generate a specific split.

        Args:
            split_hash: Hash from ``CohortConfig.config_hash()``, as stored
                in vocabulary metadata.

        Returns:
            The original ``CohortConfig`` used when the split was built.

        Raises:
            FileNotFoundError: If no cached split metadata exists for this hash.
            ValueError: If cache is disabled on this cohort.
        """
        if not self.use_cache or self._cache_dir is None:
            raise ValueError("Cannot load split config — caching is disabled on this cohort.")
        meta_path = self._cache_dir / "splits" / split_hash / "metadata.json"
        if not meta_path.exists():
            raise FileNotFoundError(
                f"No cached split metadata found for hash '{split_hash}' at {meta_path}. "
                "Ensure the cohort was built with caching enabled."
            )
        meta = json.loads(meta_path.read_text(encoding="utf-8"))
        return CohortConfig(**meta["split_config"])

    # ------------------------------------------------------------------
    # Cache paths
    # ------------------------------------------------------------------

    def _entities_table_path(self) -> Path | None:
        if not self.cache_dir:
            return None
        return self.cache_dir / "entities" / "entities_with_static.parquet"

    def _entities_metadata_path(self) -> Path | None:
        if not self.cache_dir:
            return None
        return self.cache_dir / "entities" / "metadata.json"

    def _split_table_path(self, cfg: CohortConfig) -> Path | None:
        if not self.cache_dir:
            return None
        return self.cache_dir / "splits" / cfg.config_hash() / "entities_split.parquet"

    def _split_metadata_path(self, cfg: CohortConfig) -> Path | None:
        if not self.cache_dir:
            return None
        return self.cache_dir / "splits" / cfg.config_hash() / "metadata.json"

    def vocabulary_cache_dir(self, vocab_hash: str) -> Path | None:
        """Return cache directory for a specific vocabulary artifact hash."""
        if not self.cache_dir:
            return None
        return self.cache_dir / "vocabulary" / vocab_hash

    # ------------------------------------------------------------------
    # Static attribute resolution
    # ------------------------------------------------------------------

    def _build_source_static_slice(self, source: Source) -> pl.DataFrame | None:
        """Collect one row per cohort entity holding ``source``'s static columns.

        Static columns are renamed to ``<source.name>__<col>`` and the ID
        column to ``entity_id``. For each entity the first value of every
        static column is kept.

        Args:
            source: Source whose static columns should be extracted.

        Returns:
            The collected frame, or ``None`` if the source declares no static
            columns.

        Raises:
            SchemaError: If scanning the source fails, or a declared static
                column is absent from the scan's schema.
        """
        id_col = source.config.id_col
        static_cols = self._source_static_columns(source)
        if not static_cols:
            return None

        try:
            scan = source.scan()
        except SchemaError as exc:
            raise SchemaError(
                f"Failed to build entities table for cohort '{self.name}' because "
                f"source '{source.name}' has an invalid schema: {exc}"
            ) from exc

        available_columns = set(scan.collect_schema().names())
        missing_static = [col for col in static_cols if col not in available_columns]
        if missing_static:
            raise SchemaError(
                f"Failed to build entities table for cohort '{self.name}' because "
                f"source '{source.name}' is missing static columns: {missing_static}"
            )

        rename_map = {col: f"{source.name}__{col}" for col in static_cols}
        agg_exprs = [pl.col(col).first().alias(col) for col in static_cols]

        return (
            scan
            # Cast the ID to string first: cohort IDs are strings and the
            # caller joins this slice on a Utf8 ``entity_id`` column, matching
            # what filter_df / filter_source already do.
            .select([pl.col(id_col).cast(pl.Utf8), *static_cols])
            .filter(pl.col(id_col).is_in(self.entity_id_list))
            .group_by(id_col, maintain_order=True)
            .agg(agg_exprs)
            .rename(rename_map)
            .rename({id_col: "entity_id"})
            .collect()
        )

    def _source_static_columns(self, source: Source) -> list[str]:
        cols: list[str] = []
        for group in (
            source.config.temporal_cols,
            source.config.categorical_cols,
            source.config.continuous_cols,
        ):
            if not group:
                continue
            cols.extend(col.col_name for col in group if col.static)
        return cols

    # ------------------------------------------------------------------
    # Split assignment helpers
    # ------------------------------------------------------------------

    def _assign_splits(self, entities: pl.DataFrame, cfg: CohortConfig) -> list[str]:
        if entities.height == 0:
            return []
        if not cfg.use_splits:
            return ["all"] * entities.height

        if cfg.stratify_col is None:
            return self._labels_for_indices(list(range(entities.height)), cfg, cfg.seed)

        if cfg.stratify_col not in entities.columns:
            raise ValueError(
                f"stratify_col '{cfg.stratify_col}' was not found in entities table columns: {entities.columns}"
            )

        stratum_values = entities.get_column(cfg.stratify_col).to_list()
        by_stratum: dict[str, list[int]] = {}
        for idx, value in enumerate(stratum_values):
            key = self._normalize_stratum_key(value)
            by_stratum.setdefault(key, []).append(idx)

        labels = [""] * entities.height
        for key, indices in by_stratum.items():
            stratum_seed = cfg.seed + int(self._stable_hash(key), 16)
            stratum_labels = self._labels_for_indices(indices, cfg, stratum_seed)
            for idx, split in zip(indices, stratum_labels):
                labels[idx] = split

        return labels

    def _labels_for_indices(
        self,
        indices: list[int],
        cfg: CohortConfig,
        seed: int,
    ) -> list[str]:
        n = len(indices)
        if n == 0:
            return []

        train_n = int(np.floor(n * cfg.train_frac))
        val_n = int(np.floor(n * cfg.val_frac))
        test_n = n - train_n - val_n

        labels = np.array(["test"] * n, dtype=object)
        labels[:train_n] = "train"
        labels[train_n : train_n + val_n] = "val"

        rng = np.random.default_rng(seed)
        rng.shuffle(labels)
        return labels.tolist()

    @staticmethod
    def _normalize_stratum_key(value: Any) -> str:
        return "<NULL>" if value is None else str(value)

    def _criteria_hash(self) -> str:
        payload = [criteria.model_dump(exclude_none=False) for criteria in self._criteria]
        return self._stable_hash(json.dumps(payload, sort_keys=True))

    @staticmethod
    def _stable_hash(payload: str) -> str:
        return hashlib.sha256(payload.encode()).hexdigest()[:16]

name property

name: str

Cohort name.

entity_ids property

entity_ids: set[str]

Set of entity IDs in this cohort.

entity_id_list property

entity_id_list: list[str]

Deterministically sorted entity IDs in this cohort.

cache_dir property

cache_dir: Path | None

Directory for caching cohort splits and metadata, or None if caching is disabled.

use_cache property

use_cache: bool

Whether to use caching for cohort splits.

criteria property

criteria: list[EntityInclusionCriteria] | None

A copy of the list of EntityInclusionCriteria used to define this cohort. The list is empty (never None) when no criteria were supplied.

entities_table property

entities_table: DataFrame | None

Cached entities table (entity_id + static columns), if already built.

collection property

collection: SourceCollection

Underlying source collection used by this cohort.

__init__

__init__(name: str, sources: Source | list[Source] | SourceCollection, inclusion_criteria: list[EntityInclusionCriteria] | None = None, cache_dir: str | Path = Path('data/cohorts/'), use_cache: bool = True) -> None

Initialize a Cohort with given sources and configuration.

Args:

name: Unique name for this cohort, used in caching and metadata.

sources: One or more Source objects or a SourceCollection defining the data sources for this cohort.

inclusion_criteria: Optional list of EntityInclusionCriteria to filter entities.

cache_dir: Base directory for caching cohort artifacts. If None, caching is disabled.

use_cache: Whether to enable caching for this cohort. If False, no caching will occur even if cache_dir is provided.

Raises:

ValueError: If name is empty or contains only whitespace, or if inclusion_criteria contains invalid entries.

KeyError: If inclusion_criteria references a source not in the collection.

TypeError: If sources is not a valid type (Source, list[Source], or SourceCollection).

Source code in tab2seq/cohort/core.py
def __init__(
    self,
    name: str,
    sources: Source | list[Source] | SourceCollection,
    inclusion_criteria: list[EntityInclusionCriteria] | None = None,
    cache_dir: str | Path | None = Path("data/cohorts/"),
    use_cache: bool = True,
) -> None:
    """Initialize a Cohort with given sources and configuration.

    Entity IDs are resolved eagerly at construction time by applying the
    inclusion criteria to the sources.

    Args:
        name: Unique name for this cohort, used in caching and metadata.
        sources: One or more `Source` objects or a `SourceCollection`
            defining the data sources for this cohort.
        inclusion_criteria: Optional list of `EntityInclusionCriteria` to
            filter entities. The list is copied, so later changes to the
            caller's list do not affect this cohort.
        cache_dir: Base directory for caching cohort artifacts; the cohort
            uses the ``cache_dir / name`` subdirectory. If None (or empty),
            caching is disabled.
        use_cache: Whether to enable caching for this cohort. If False, no
            caching will occur even if `cache_dir` is provided.

    Raises:
        ValueError: If `name` is not a string, is empty or whitespace-only,
            or has leading/trailing whitespace.
        KeyError: If `inclusion_criteria` references a source not in the
            collection.
        TypeError: If `sources` is not a valid type (Source, list[Source],
            or SourceCollection).
    """
    if not isinstance(name, str) or not name.strip():
        raise ValueError("'name' must be a non-empty string.")
    if name != name.strip():
        raise ValueError("'name' cannot have leading or trailing whitespace.")

    self._name = name

    if isinstance(sources, Source):
        self._collection = SourceCollection([sources])
    elif isinstance(sources, list):
        self._collection = SourceCollection(sources)
    elif isinstance(sources, SourceCollection):
        self._collection = sources
    else:
        raise TypeError(
            "'sources' must be a Source, list[Source], or SourceCollection."
        )

    # Copy so the cohort definition cannot drift if the caller mutates
    # the list they passed in after construction.
    self._criteria = list(inclusion_criteria) if inclusion_criteria else []

    base_cache_dir = Path(cache_dir) if cache_dir else None
    self._cache_dir = base_cache_dir / self._name if base_cache_dir else None
    if self._cache_dir and use_cache:
        self._cache_dir.mkdir(parents=True, exist_ok=True)
        self._use_cache = True
    else:
        logger.warning(
            "Caching is disabled for this cohort. To enable caching, provide a valid 'cache_dir' and set 'use_cache=True'."
        )
        self._use_cache = False
        self._cache_dir = None

    self._entity_ids: set[str] = self._resolve_entity_ids()
    self._entities_table: pl.DataFrame | None = None

filter_df

filter_df(df: DataFrame, entity_id_col: str = 'entity_id') -> pl.DataFrame

Filter a DataFrame to only cohort entities.

Source code in tab2seq/cohort/core.py
def filter_df(self, df: pl.DataFrame, entity_id_col: str = "entity_id") -> pl.DataFrame:
    """Keep only the rows of ``df`` that belong to entities in this cohort.

    Args:
        df: Frame to filter.
        entity_id_col: Name of the column in ``df`` that holds entity identifiers.

    Returns:
        The rows of ``df`` whose identifier, cast to a string, is contained in
        ``self.entity_id_list``.
    """
    # The identifier column is cast to Utf8 before the membership test.
    ids_as_text = pl.col(entity_id_col).cast(pl.Utf8)
    belongs_to_cohort = ids_as_text.is_in(self.entity_id_list)
    return df.filter(belongs_to_cohort)

filter_source

filter_source(source: Source) -> pl.LazyFrame

Filter a source scan to only cohort entities.

Source code in tab2seq/cohort/core.py
def filter_source(self, source: Source) -> pl.LazyFrame:
    """Scan ``source`` and keep only the rows that belong to cohort entities.

    Args:
        source: Source to scan; its ``config.id_col`` names the identifier column.

    Returns:
        The result of ``source.scan()`` filtered to rows whose identifier,
        cast to a string, is contained in ``self.entity_id_list``.
    """
    scanned = source.scan()
    # The identifier column is cast to Utf8 before the membership test.
    ids_as_text = pl.col(source.config.id_col).cast(pl.Utf8)
    belongs_to_cohort = ids_as_text.is_in(self.entity_id_list)
    return scanned.filter(belongs_to_cohort)

build_entities_table

build_entities_table(force_recompute: bool = False) -> pl.DataFrame

Build or load entity_id plus static properties for this cohort.

Source code in tab2seq/cohort/core.py
def build_entities_table(self, force_recompute: bool = False) -> pl.DataFrame:
    """Return the cohort's entity table, reading it from cache when possible.

    The table starts from an ``entity_id`` column holding the cohort's ids and
    is left-joined with the static slice of every source in the collection
    (sources for which no static slice is produced are skipped). The result is
    sorted by ``entity_id`` and remembered on the instance.

    Args:
        force_recompute: When True, ignore any cached parquet file and rebuild.

    Returns:
        The entity table, either loaded from the cached parquet file or
        freshly built.
    """
    entities_path = self._entities_table_path()
    metadata_path = self._entities_metadata_path()

    cache_hit = (
        self.use_cache
        and not force_recompute
        and entities_path is not None
        and entities_path.exists()
    )
    if cache_hit:
        logger.info("Loading entities table from cache: %s", entities_path)
        self._entities_table = pl.read_parquet(entities_path)
        return self._entities_table

    # Seed frame: one string-typed ``entity_id`` column with the cohort ids.
    id_series = pl.Series(self.entity_id_list, dtype=pl.Utf8)
    table = pl.DataFrame({"entity_id": id_series})

    for source in self._collection:
        static_slice = self._build_source_static_slice(source)
        if static_slice is None:
            continue
        table = table.join(static_slice, on="entity_id", how="left")

    table = table.sort("entity_id")
    self._entities_table = table

    can_persist = (
        self.use_cache and entities_path is not None and metadata_path is not None
    )
    if not can_persist:
        return table

    # Persist the table together with a JSON sidecar describing how it was built.
    entities_path.parent.mkdir(parents=True, exist_ok=True)
    table.write_parquet(entities_path)

    metadata = {
        "cohort_name": self.name,
        "n_entities": table.height,
        "generated_at_utc": datetime.now(timezone.utc).isoformat(),
        "criteria_hash": self._criteria_hash(),
    }
    metadata["source_config_hashes"] = {
        source.name: self._stable_hash(source.config.model_dump_json())
        for source in self._collection
    }
    metadata["columns"] = table.columns

    serialized = json.dumps(metadata, indent=2)
    metadata_path.write_text(serialized, encoding="utf-8")

    return table

build_or_load_splits

build_or_load_splits(split_config: CohortConfig | None = None, force_recompute: bool = False) -> pl.DataFrame

Build or load cohort splits with full static context per entity.

Source code in tab2seq/cohort/core.py
def build_or_load_splits(
    self,
    split_config: CohortConfig | None = None,
    force_recompute: bool = False,
) -> pl.DataFrame:
    """Return the entity table extended with a ``split`` column.

    A cached split table is read from disk when caching is enabled, a file
    exists for this configuration, and ``force_recompute`` is False. Otherwise
    the entity table is built, split labels are assigned, and (when caching is
    enabled) the table and a JSON metadata sidecar are written.

    Args:
        split_config: Split configuration; a 0.7 / 0.15 / 0.15 configuration
            is used when none is given.
        force_recompute: When True, bypass the cached split table and also
            force the entity table to be rebuilt.

    Returns:
        The entity table with an added ``split`` column.
    """
    cfg = (
        split_config
        if split_config
        else CohortConfig(train_frac=0.7, val_frac=0.15, test_frac=0.15)
    )
    split_path = self._split_table_path(cfg)
    metadata_path = self._split_metadata_path(cfg)

    cache_hit = (
        self.use_cache
        and not force_recompute
        and split_path is not None
        and split_path.exists()
    )
    if cache_hit:
        logger.info("Loading split table from cache: %s", split_path)
        return pl.read_parquet(split_path)

    entities = self.build_entities_table(force_recompute=force_recompute)
    labels = self._assign_splits(entities, cfg)
    label_column = pl.Series(name="split", values=labels)
    labelled = entities.with_columns(label_column)

    can_persist = (
        self.use_cache and split_path is not None and metadata_path is not None
    )
    if not can_persist:
        return labelled

    split_path.parent.mkdir(parents=True, exist_ok=True)
    labelled.write_parquet(split_path)

    # Per-split row counts, recorded in the sidecar as a list of dicts.
    per_split = labelled.group_by("split").len()
    split_counts = per_split.rename({"len": "count"}).to_dicts()

    metadata = {
        "cohort_name": self.name,
        "split_config": cfg.model_dump(exclude_none=False),
        "split_config_hash": cfg.config_hash(),
        "n_entities": labelled.height,
        "split_counts": split_counts,
        "generated_at_utc": datetime.now(timezone.utc).isoformat(),
    }
    serialized = json.dumps(metadata, indent=2)
    metadata_path.write_text(serialized, encoding="utf-8")

    return labelled

load_split_config

load_split_config(split_hash: str) -> CohortConfig

Reconstruct the CohortConfig used to generate a specific split.

Parameters:

Name Type Description Default
split_hash str

Hash from CohortConfig.config_hash(), as stored in vocabulary metadata.

required

Returns:

Type Description
CohortConfig

The original CohortConfig used when the split was built.

Raises:

Type Description
FileNotFoundError

If no cached split metadata exists for this hash.

ValueError

If cache is disabled on this cohort.

Source code in tab2seq/cohort/core.py
def load_split_config(self, split_hash: str) -> CohortConfig:
    """Rebuild the ``CohortConfig`` recorded for a cached split.

    The configuration is read from the ``split_config`` entry of
    ``<cache_dir>/splits/<split_hash>/metadata.json``.

    Args:
        split_hash: Hash from ``CohortConfig.config_hash()``, as stored
            in vocabulary metadata.

    Returns:
        A ``CohortConfig`` constructed from the stored ``split_config`` values.

    Raises:
        ValueError: If cache is disabled on this cohort.
        FileNotFoundError: If no cached split metadata exists for this hash.
    """
    caching_active = self.use_cache and self._cache_dir is not None
    if not caching_active:
        raise ValueError("Cannot load split config — caching is disabled on this cohort.")

    meta_path = self._cache_dir.joinpath("splits", split_hash, "metadata.json")
    if meta_path.exists():
        stored = json.loads(meta_path.read_text(encoding="utf-8"))
        return CohortConfig(**stored["split_config"])

    raise FileNotFoundError(
        f"No cached split metadata found for hash '{split_hash}' at {meta_path}. "
        "Ensure the cohort was built with caching enabled."
    )

vocabulary_cache_dir

vocabulary_cache_dir(vocab_hash: str) -> Path | None

Return cache directory for a specific vocabulary artifact hash.

Source code in tab2seq/cohort/core.py
def vocabulary_cache_dir(self, vocab_hash: str) -> Path | None:
    """Return cache directory for a specific vocabulary artifact hash."""
    if not self.cache_dir:
        return None
    return self.cache_dir / "vocabulary" / vocab_hash