Skip to content

Sparse Aware MIL (sAwMIL)

sawmil.sawmil.sAwMIL dataclass

sAwMIL(
    C: float = 1.0,
    kernel: KernelType = "Linear",
    normalizer: str = "none",
    p: float = 1.0,
    fast_linear: bool = True,
    smil_scale_C: bool = True,
    tol: float = 1e-08,
    verbose: bool = False,
    solver: str = "gurobi",
    eta: float = 0.1,
    min_pos_ratio: float = 0.05,
    smil_: sMIL | None = None,
    sil_: SVM | None = None,
    classes_: NDArray[float64] | None = None,
    coef_: NDArray[float64] | None = None,
    intercept_: float | None = None,
    cutoff_: float | None = None,
)

Bases: BaseEstimator, ClassifierMixin

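Sparse-Aware Multiple-Instance Learning (sAwMIL) classifier: a two-stage SVM that ranks instances with sMIL, then trains an instance-level SVM.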

decision_function

decision_function(
    bags: Sequence[Bag] | BagDataset | Sequence[ndarray],
) -> npt.NDArray[np.float64]

Compute the decision function for the given bags.

Source code in src/sawmil/sawmil.py (lines 202–213)
def decision_function(self, bags: Sequence[Bag] | BagDataset | Sequence[np.ndarray]) -> npt.NDArray[np.float64]:
    '''Compute the decision function for the given bags.

    Each bag is scored by the maximum stage-2 instance-SVM (``sil_``) score
    over its instances. An empty bag falls back to the SVM intercept, or 0.0
    when the intercept is unset/falsy.

    Parameters
    ----------
    bags
        Bags to score; converted with ``_coerce_bags``.

    Returns
    -------
    ndarray of float, shape (n_bags,)
        One score per bag, in input order.

    Raises
    ------
    RuntimeError
        If the model has not been fitted.
    '''
    # Check fitted state before touching the input so an unfitted model
    # reports "not fitted" instead of an unrelated input-coercion error.
    if self.sil_ is None:
        raise RuntimeError("sAwMIL is not fitted.")
    blist = self._coerce_bags(bags)
    scores = np.empty(len(blist), dtype=float)
    for i, b in enumerate(blist):
        if b.n == 0:
            # No instances to score: use the bias term alone.
            scores[i] = float(self.sil_.intercept_ or 0.0)
        else:
            # Max rule: a bag is as positive as its highest-scoring instance.
            scores[i] = float(np.max(self.sil_.decision_function(b.X)))
    return scores

fit

fit(
    bags: Sequence[Bag] | BagDataset | Sequence[ndarray],
    y: Optional[NDArray[float64]] = None,
    intra_bag_labels: Optional[Sequence[ndarray]] = None,
) -> "sAwMIL"

Fit the model to the training data.

Source code in src/sawmil/sawmil.py (lines 97–199)
def fit(
    self,
    bags: Sequence[Bag] | BagDataset | Sequence[np.ndarray],
    y: Optional[npt.NDArray[np.float64]] = None,
    intra_bag_labels: Optional[Sequence[np.ndarray]] = None,
) -> "sAwMIL":
    '''Fit the two-stage sAwMIL model to the training data.

    Stage 1 fits an sMIL model on the bags and uses its scores to rank the
    instances of positive bags. Stage 2 trains an instance-level SVM in
    which the selected top-``eta`` instances of positive bags are labelled
    +1 and every other instance is labelled -1.

    Parameters
    ----------
    bags
        Training bags; converted together with ``y`` and
        ``intra_bag_labels`` by ``_coerce_bags``.
    y
        Optional bag labels, forwarded to ``_coerce_bags``.
    intra_bag_labels
        Optional per-bag instance labels, forwarded to ``_coerce_bags``.

    Returns
    -------
    sAwMIL
        ``self``, with ``smil_``, ``sil_``, ``coef_``, ``intercept_`` and
        ``cutoff_`` assigned.

    Raises
    ------
    ValueError
        If no bags are provided.
    '''
    # NOTE(review): ``classes_`` is declared on the estimator but is not
    # assigned anywhere in this method — confirm whether a helper sets it.

    # 1) Coerce the input into a list of bags.
    blist = self._coerce_bags(bags, y, intra_bag_labels)
    if not blist:
        raise ValueError("No bags provided.")

    # 2) Stage 1: fit sMIL. Its decision function on singleton bags is used
    #    in step 4 to rank the instances of positive bags.
    smil = sMIL(
        C=self.C,
        kernel=self.kernel,
        normalizer=self.normalizer,
        p=self.p,
        use_intra_labels=True,            # stage 1 always uses intra-bag labels
        fast_linear=self.fast_linear,
        scale_C=self.smil_scale_C,        # sMIL-only C-scaling flag; stage 2 SVM does not take it
        tol=self.tol,
        verbose=self.verbose,
    )
    smil.fit(blist)
    self.smil_ = smil

    # 3) Flatten all bags into a single instance matrix.
    #    NOTE(review): y_bag and mask are presumably per-instance arrays
    #    (label of the owning bag, intra-bag mask) aligned with X_all's
    #    rows — confirm against _flatten.
    X_all, y_bag, mask, _ = self._flatten(blist)
    if X_all.shape[0] == 0:
        # Degenerate case: no instances at all. Install an SVM that is never
        # fitted, with an all-zero weight vector and zero intercept.
        # NOTE(review): kernel is hard-coded to "linear" here while the
        # estimator default is "Linear" — confirm SVM accepts this spelling.
        self.sil_ = SVM(C=self.C, kernel="linear",
                        solver=self.solver, tol=self.tol, verbose=self.verbose)
        self.sil_.coef_, self.sil_.intercept_ = np.zeros(
            (blist[0].d,)), 0.0
        self.coef_, self.intercept_, self.cutoff_ = self.sil_.coef_, self.sil_.intercept_, 0.0
        return self

    # Split instances by whether their owning bag's label is > 0.
    pos_inst = (y_bag > 0)
    X_pos = X_all[pos_inst]
    mask_pos = mask[pos_inst]
    X_neg = X_all[~pos_inst]

    # No positive-bag instances: train the instance SVM on negatives only
    # (every label -1). No selection happens, so cutoff_ is -inf.
    if X_pos.shape[0] == 0:
        X_sil = X_neg
        y_sil = -np.ones(X_neg.shape[0], dtype=float)
        sil = SVM(C=self.C, kernel=self.kernel, solver=self.solver,
                  tol=self.tol, verbose=self.verbose)
        sil.fit(X_sil, y_sil)
        self.sil_ = sil
        self.coef_ = sil.coef_.ravel() if sil.coef_ is not None else None
        self.intercept_ = float(
            sil.intercept_) if sil.intercept_ is not None else None
        self.cutoff_ = float("-inf")
        return self

    # 4) Score every positive-bag instance with the stage-1 sMIL model by
    #    wrapping instances as singleton bags labelled +1 (presumably one
    #    bag per row of X_pos — see _singletonize).
    pos_singletons = self._singletonize(X_pos, y=+1.0)
    S_pos = smil.decision_function(pos_singletons).ravel()

    # 5) Select the top-eta fraction of positive-bag instances, restricted
    #    to those whose intra-bag mask is >= 0.5. eta is clamped to [1e-9, 1].
    eta = float(self.eta)
    eta = min(max(eta, 1e-9), 1.0)
    if S_pos.size == 0:
        # Defensive only: X_pos is non-empty here, so this is unreachable
        # as long as sMIL returns one score per singleton bag.
        q = float("-inf")
        chosen = np.zeros(0, dtype=bool)
    else:
        # Threshold at the (1 - eta) quantile of the sMIL scores.
        q = float(np.quantile(S_pos, 1.0 - eta, method="linear"))
        chosen = (S_pos >= q) & (mask_pos >= 0.5)
        # Fallback: require at least max(1, floor(min_pos_ratio * n))
        # selected instances. If fewer pass, select the top-k scores instead,
        # with k = max(min_needed, round(eta * n)) capped at n. This path
        # ignores the intra-bag mask.
        min_needed = max(1, int(self.min_pos_ratio * len(S_pos)))
        if chosen.sum() < min_needed:
            k = min(len(S_pos), max(
                min_needed, int(round(eta * len(S_pos)))))
            topk = np.argsort(-S_pos)[:k]
            chosen = np.zeros_like(chosen)
            chosen[topk] = True

    # NOTE(review): cutoff_ records the quantile threshold q even when the
    # top-k fallback above decided the selection.
    self.cutoff_ = q

    # 6) Build the stage-2 dataset. Instances of negative bags get -1.
    #    Instances of positive bags get +1 if selected, otherwise -1.
    #    Rows are ordered negatives first, then positive-bag instances.
    y_pos = np.full(X_pos.shape[0], -1.0, dtype=float)
    y_pos[chosen] = +1.0
    X_sil = np.vstack([X_neg, X_pos])
    y_sil = np.hstack([-np.ones(X_neg.shape[0], dtype=float), y_pos])

    # 7) Stage 2: train the instance-level SVM. Unlike the stage-1 sMIL,
    #    it receives the configured solver.
    sil = SVM(
        C=self.C,
        kernel=self.kernel,
        solver=self.solver,
        tol=self.tol,
        verbose=self.verbose,
    )
    sil.fit(X_sil, y_sil)
    self.sil_ = sil

    # Expose the stage-2 weights (flattened) and bias on the estimator.
    self.coef_ = sil.coef_.ravel() if sil.coef_ is not None else None
    self.intercept_ = float(
        sil.intercept_) if sil.intercept_ is not None else None
    return self

predict

predict(
    bags: Sequence[Bag] | BagDataset | Sequence[ndarray],
) -> npt.NDArray[np.float64]

Predict the labels for the given bags.

Source code in src/sawmil/sawmil.py (lines 215–217)
def predict(self, bags: Sequence[Bag] | BagDataset | Sequence[np.ndarray]) -> npt.NDArray[np.float64]:
    '''Predict a 0/1 label for each bag.

    A bag is labelled 1.0 when its decision score is non-negative and 0.0
    otherwise. The result is a float array with one entry per bag.
    '''
    bag_scores = self.decision_function(bags)
    return np.where(bag_scores >= 0.0, 1.0, 0.0)

score

score(bags, y_true) -> float

Compute the accuracy of the model on the given bags.

Source code in src/sawmil/sawmil.py (lines 219–228)
def score(self, bags, y_true) -> float:
    '''Compute the accuracy of the model on the given bags.

    Ground-truth labels are read from the bags themselves when ``bags`` is a
    ``BagDataset`` or a non-empty sequence of ``Bag`` objects; otherwise
    ``y_true`` is used.

    Labels are compared by positivity (``> 0``), the rule ``fit`` uses to
    separate positive from negative bags. Both {0, 1} and {-1, +1} encodings
    are therefore scored correctly against the 0.0/1.0 output of
    ``predict``.

    Returns
    -------
    float
        Fraction of bags whose predicted class matches the true class.

    Raises
    ------
    ValueError
        If the number of true labels differs from the number of bags.
    '''
    y_pred = self.predict(bags)
    if isinstance(bags, BagDataset):
        y_true_arr = np.asarray([b.y for b in bags.bags], dtype=float)
    elif len(bags) and isinstance(bags[0], Bag):  # type: ignore[index]
        y_true_arr = np.asarray([b.y for b in bags], dtype=float)
    else:
        # ravel() so a column vector of labels does not broadcast to (n, n).
        y_true_arr = np.asarray(y_true, dtype=float).ravel()
    if y_true_arr.shape[0] != y_pred.shape[0]:
        raise ValueError(
            f"Got {y_true_arr.shape[0]} labels for {y_pred.shape[0]} bags."
        )
    # predict() emits 0.0/1.0. Comparing positivity instead of raw values
    # keeps -1-labelled negatives from all being counted as errors.
    return float(((y_pred > 0) == (y_true_arr > 0)).mean())