
zamba.images.config

ImageClassificationPredictConfig

Bases: ZambaImageConfig

Configuration for using an image classification model for inference.

Parameters:

data_dir (DirectoryPath): Path to a directory containing images for inference. Defaults to the current working directory.
filepaths (FilePath, optional): Path to a CSV containing images for inference, with one row per image in the data_dir. There must be a column called 'filepath' (absolute or relative to the data_dir). If None, uses all image files in data_dir. Defaults to None.
checkpoint (FilePath, optional): Path to a custom checkpoint file (.ckpt) generated by zamba that can be used to generate predictions. If None, defaults to a pretrained model. Defaults to None.
model_name (str, optional): Name of the model to use for inference. Currently supports 'lila.science'. Defaults to 'lila.science'.
save (bool): Whether to save out predictions. If False, predictions are not saved. Defaults to True.
save_dir (Path, optional): Directory in which to save the model predictions and configuration yaml. If no save_dir is specified and save=True, outputs are written to the current working directory. Defaults to None.
overwrite (bool): If True, overwrite outputs in save_dir if they exist. Defaults to False.
crop_images (bool): Whether to preprocess images using Megadetector or bounding boxes from the labels file, focusing the model on regions of interest. Defaults to True.
detections_threshold (float): Confidence threshold for Megadetector detections. Only applied when crop_images=True. Defaults to 0.2.
gpus (int): Number of GPUs to use for inference. Defaults to all available GPUs found on the machine.
num_workers (int): Number of subprocesses to use for data loading. Defaults to 3.
image_size (int, optional): Image size (height and width) for the input to the classification model. Defaults to 224.
results_file_format (ResultsFormat): The format in which to output the predictions. Currently 'csv' and 'megadetector' JSON formats are supported. Defaults to 'csv'.
results_file_name (Path, optional): The filename for the output predictions in the save directory. Defaults to 'zamba_predictions.csv' or 'zamba_predictions.json' depending on results_file_format.
model_cache_dir (Path, optional): Cache directory where downloaded model weights will be saved. If None and no environment variable is set, uses your default cache directory. Defaults to None.
weight_download_region (str): S3 region to download pretrained weights from. Options are "us" (United States), "eu" (Europe), or "asia" (Asia Pacific). Defaults to "us".
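
For orientation, a minimal sketch of constructing this config. The directory names below are hypothetical; validation requires that data_dir exists.

from pathlib import Path

from zamba.images.config import ImageClassificationPredictConfig

# Hypothetical paths; data_dir must exist and contain images.
config = ImageClassificationPredictConfig(
    data_dir=Path("camera_trap_images"),
    save_dir=Path("predictions"),
    overwrite=True,
)
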
Source code in zamba/images/config.py
class ImageClassificationPredictConfig(ZambaImageConfig):
    """
    Configuration for using an image classification model for inference.

    Args:
        data_dir (DirectoryPath): Path to a directory containing images for
            inference. Defaults to the current working directory.
        filepaths (FilePath, optional): Path to a CSV containing images for inference, with
            one row per image in the data_dir. There must be a column called
            'filepath' (absolute or relative to the data_dir). If None, uses
            all image files in data_dir. Defaults to None.
        checkpoint (FilePath, optional): Path to a custom checkpoint file (.ckpt)
            generated by zamba that can be used to generate predictions. If None,
            defaults to a pretrained model. Defaults to None.
        model_name (str, optional): Name of the model to use for inference. Currently
            supports 'lila.science'. Defaults to 'lila.science'.
        save (bool): Whether to save out predictions. If False, predictions are
            not saved. Defaults to True.
        save_dir (Path, optional): An optional directory in which to save the model
             predictions and configuration yaml. If no save_dir is specified and save=True,
             outputs will be written to the current working directory. Defaults to None.
        overwrite (bool): If True, overwrite outputs in save_dir if they exist.
            Defaults to False.
        crop_images (bool): Whether to preprocess images using Megadetector or bounding boxes
            from labels file. Focuses the model on regions of interest. Defaults to True.
        detections_threshold (float): Confidence threshold for Megadetector detections.
            Only applied when crop_images=True. Defaults to 0.2.
        gpus (int): Number of GPUs to use for inference.
            Defaults to all of the available GPUs found on the machine.
        num_workers (int): Number of subprocesses to use for data loading.
            Defaults to 3.
        image_size (int, optional): Image size (height and width) for the input to the
            classification model. Defaults to 224.
        results_file_format (ResultsFormat): The format in which to output the predictions.
            Currently 'csv' and 'megadetector' JSON formats are supported. Default is 'csv'.
        results_file_name (Path, optional): The filename for the output predictions in the
            save directory. Defaults to 'zamba_predictions.csv' or 'zamba_predictions.json'
            depending on results_file_format.
        model_cache_dir (Path, optional): Cache directory where downloaded model weights
            will be saved. If None and no environment variable is set, will use your
            default cache directory. Defaults to None.
        weight_download_region (str): S3 region to download pretrained weights from.
            Options are "us" (United States), "eu" (Europe), or "asia" (Asia Pacific).
            Defaults to "us".
    """

    checkpoint: Optional[FilePath] = None
    model_name: Optional[str] = ImageModelEnum.LILA_SCIENCE.value
    filepaths: Optional[Union[FilePath, pd.DataFrame]] = None
    data_dir: DirectoryPath
    save: bool = True
    overwrite: bool = False
    crop_images: bool = True
    detections_threshold: float = 0.2
    gpus: int = GPUS_AVAILABLE
    num_workers: int = 3
    image_size: Optional[int] = 224
    results_file_format: ResultsFormat = ResultsFormat.CSV
    results_file_name: Optional[Path] = Path("zamba_predictions.csv")
    model_cache_dir: Optional[Path] = None
    weight_download_region: str = RegionEnum.us.value

    class Config:  # type: ignore
        arbitrary_types_allowed = True

    _validate_model_cache_dir = validator("model_cache_dir", allow_reuse=True, always=True)(
        validate_model_cache_dir
    )

    _get_filepaths = root_validator(allow_reuse=True, pre=False, skip_on_failure=True)(
        get_image_filepaths
    )

    @root_validator(skip_on_failure=True)
    def validate_save_dir(cls, values):
        save_dir = values["save_dir"]
        results_file_name = values["results_file_name"]
        save = values["save"]

        # if no save_dir but save is True, use current working directory
        if save_dir is None and save:
            save_dir = Path.cwd()

        if save_dir is not None:
            # check if files exist
            save_path = save_dir / results_file_name
            if values["results_file_format"] == ResultsFormat.MEGADETECTOR:
                save_path = save_path.with_suffix(".json")
            if save_path.exists() and not values["overwrite"]:
                raise ValueError(
                    f"{save_path.name} already exists in {save_dir}. If you would like to overwrite, set overwrite=True"
                )

            # make a directory if needed
            save_dir.mkdir(parents=True, exist_ok=True)

            # set save to True if save_dir is set
            if not save:
                save = True

        values["save_dir"] = save_dir
        values["save"] = save

        return values

    @root_validator(skip_on_failure=True)
    def validate_filepaths(cls, values):
        if isinstance(values["filepaths"], pd.DataFrame):
            files_df = values["filepaths"]
        else:
            files_df = pd.DataFrame(pd.read_csv(values["filepaths"]))

        if "filepath" not in files_df.columns:
            raise ValueError(f"{values['filepath']} must contain a `filepath` column.")
        else:
            files_df = files_df[["filepath"]]

        duplicated = files_df.filepath.duplicated()
        if duplicated.sum() > 0:
            logger.warning(
                f"Found {duplicated.sum():,} duplicate row(s) in filepaths csv. Dropping duplicates so predictions will have one row per image."
            )
            files_df = files_df.drop_duplicates()

        # The filepath column can come in as a str or a Path-like and either absolute
        # or relative to the data directory. Handle all those cases.
        filepaths = []
        for path in files_df["filepath"]:
            path = Path(path)
            if not path.is_absolute():
                # Assume relative to data directory
                path = values["data_dir"] / path
            filepaths.append(str(path))
        files_df["filepath"] = filepaths
        values["filepaths"] = files_df
        return values

    @root_validator(skip_on_failure=True)
    def validate_detections_threshold(cls, values):
        threshold = values["detections_threshold"]

        if threshold <= 0 or threshold >= 1:
            raise ValueError(
                "Detections threshold value should be greater than zero and less than one."
            )

        return values

    @root_validator(skip_on_failure=True)
    def validate_image_size(cls, values):
        if values["image_size"] <= 0:
            raise ValueError("Image size should be greater than or equal 64")
        return values

    _validate_model_name_and_checkpoint = root_validator(allow_reuse=True, skip_on_failure=True)(
        validate_model_name_and_checkpoint
    )
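
Since filepaths accepts either a CSV path or a pandas DataFrame, an in-memory frame can also be passed directly. A sketch with hypothetical paths (relative entries are resolved against data_dir by validate_filepaths):

import pandas as pd

from zamba.images.config import ImageClassificationPredictConfig

# Hypothetical paths; relative filepaths are resolved against data_dir.
filepaths = pd.DataFrame({"filepath": ["day1/img_001.jpg", "day1/img_002.jpg"]})
config = ImageClassificationPredictConfig(
    data_dir="camera_trap_images",
    filepaths=filepaths,
)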

ImageClassificationTrainingConfig

Bases: ZambaImageConfig

Configuration for running image classification training.

Parameters:

data_dir (Path): Path to the directory containing the training images.
labels (Union[FilePath, DataFrame]): Path to a CSV or JSON file with labels, or a pandas DataFrame. CSV files must contain 'filepath' and 'label' columns. JSON files must be in COCO or another supported format, as specified by labels_format.
labels_format (BboxInputFormat): Format for bounding box annotations when labels are provided as JSON. Options are defined in the BboxInputFormat enum. Defaults to BboxInputFormat.COCO.
checkpoint (FilePath, optional): Path to a custom checkpoint file (.ckpt) generated by zamba that can be used to resume training. If None and from_scratch=False, will load a pretrained model. Defaults to None.
model_name (str, optional): Name of the model to use. Currently supports 'lila.science'. Defaults to 'lila.science'.
name (str, optional): Classification experiment name used for MLFlow tracking. Defaults to 'image-classification'.
max_epochs (int): Maximum number of training epochs. Defaults to 100.
lr (float, optional): Learning rate. If None, will attempt to find a good learning rate. Defaults to None.
image_size (int): Input image size (height and width) for the model. Defaults to 224.
batch_size (int, optional): Physical batch size for training. Defaults to 16.
accumulated_batch_size (int, optional): Virtual batch size for gradient accumulation. Useful for matching batch sizes from published papers when GPU memory is constrained. If None, uses batch_size. Defaults to None.
early_stopping_patience (int): Number of epochs with no improvement after which training will be stopped. Defaults to 3.
extra_train_augmentations (bool): Whether to use additional image augmentations. If False, uses simple transforms suited to camera trap imagery (random perspective shift, random horizontal flip, random rotation). If True, adds more complex transforms beyond the basic set (random grayscale, equalize, etc.). Defaults to False.
num_workers (int): Number of workers for data loading. Defaults to 2/3 of the available CPU cores.
accelerator (str): PyTorch Lightning accelerator type ('gpu' or 'cpu'). Defaults to 'gpu' if CUDA is available, otherwise 'cpu'.
devices (Any): Which devices to use for training. Can be an int, a list of ints, or 'auto'. Defaults to 'auto'.
crop_images (bool): Whether to preprocess images using Megadetector or bounding boxes from labels. Defaults to True.
detections_threshold (float): Confidence threshold for Megadetector. Only used when crop_images=True and no bounding boxes are provided in labels. Defaults to 0.2.
checkpoint_path (Path): Directory where training outputs will be saved. Defaults to the current working directory.
weighted_loss (bool): Whether to use class-weighted loss during training; helpful for imbalanced datasets. Defaults to False.
mlflow_tracking_uri (str, optional): URI for the MLFlow tracking server. Defaults to './mlruns'.
from_scratch (bool): Whether to train the model from scratch (base weights) instead of from a pretrained checkpoint. Defaults to False.
use_default_model_labels (bool, optional): Whether to use the full set of default model labels or only the labels in the provided dataset. If False, replaces the model head for finetuning and outputs only the species in the provided labels file. If None, automatically determined from the labels provided.
scheduler_config (Union[str, SchedulerConfig], optional): Learning rate scheduler configuration. If 'default', uses the scheduler from the original training. Defaults to 'default'.
split_proportions (Dict[str, int], optional): Proportions for train/val/test splits when no split column is provided in labels. Defaults to {'train': 3, 'val': 1, 'test': 1}.
model_cache_dir (Path, optional): Directory where downloaded model weights will be cached. If None, uses the system's default cache directory. Defaults to None.
cache_dir (Path, optional): Directory where cropped/processed images will be cached. Defaults to an 'image_cache' subdirectory of the system's cache directory.
weight_download_region (str): S3 region for downloading pretrained weights. Options are 'us', 'eu', or 'asia'. Defaults to 'us'.
species_in_label_order (list, optional): Optional list specifying the order of species labels in the model output. Defaults to None.
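
A minimal training sketch under the same caveats: the paths are hypothetical, and the labels CSV must provide 'filepath' and 'label' columns.

from zamba.images.config import ImageClassificationTrainingConfig

# Hypothetical paths; labels.csv needs 'filepath' and 'label' columns.
config = ImageClassificationTrainingConfig(
    data_dir="camera_trap_images",
    labels="labels.csv",
    max_epochs=10,
)
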
Source code in zamba/images/config.py
class ImageClassificationTrainingConfig(ZambaImageConfig):
    """Configuration for running image classification training.

    Args:
        data_dir (Path): Path to directory containing the training images.
        labels (Union[FilePath, pd.DataFrame]): Path to a CSV or JSON file with labels, or a pandas DataFrame.
            For CSV files, must contain 'filepath' and 'label' columns.
            For JSON files, must be in COCO or other supported format as specified by labels_format.
        labels_format (BboxInputFormat): Format for bounding box annotations when labels are provided as JSON.
            Options are defined in the BboxInputFormat enum. Defaults to BboxInputFormat.COCO.
        checkpoint (FilePath, optional): Path to a custom checkpoint file (.ckpt) generated by zamba
            that can be used to resume training. If None and from_scratch=False, will load a pretrained model.
            Defaults to None.
        model_name (str, optional): Name of the model to use. Currently supports 'lila.science'.
            Defaults to 'lila.science'.
        name (str, optional): Classification experiment name used for MLFlow tracking.
            Defaults to 'image-classification'.
        max_epochs (int): Maximum number of training epochs. Defaults to 100.
        lr (float, optional): Learning rate. If None, will attempt to find a good learning rate.
            Defaults to None.
        image_size (int): Input image size (height and width) for the model. Defaults to 224.
        batch_size (int, optional): Physical batch size for training. Defaults to 16.
        accumulated_batch_size (int, optional): Virtual batch size for gradient accumulation.
            Useful to match batch sizes from published papers with constrained GPU memory.
            If None, uses batch_size. Defaults to None.
        early_stopping_patience (int): Number of epochs with no improvement after which training
            will be stopped. Defaults to 3.
        extra_train_augmentations (bool): Whether to use additional image augmentations.
            If False, uses simple transforms for camera trap imagery (random perspective shift,
            random horizontal flip, random rotation).
            If True, adds complex transforms beyond the basic set (random grayscale, equalize, etc.).
            Defaults to False.
        num_workers (int): Number of workers for data loading. Defaults to 2/3 of available CPU cores.
        accelerator (str): PyTorch Lightning accelerator type ('gpu' or 'cpu').
            Defaults to 'gpu' if CUDA is available, otherwise 'cpu'.
        devices (Any): Which devices to use for training. Can be int, list of ints, or 'auto'.
            Defaults to 'auto'.
        crop_images (bool): Whether to preprocess images using Megadetector or bounding boxes
            from labels. Defaults to True.
        detections_threshold (float): Confidence threshold for Megadetector.
            Only used when crop_images=True and no bounding boxes are provided in labels.
            Defaults to 0.2.
        checkpoint_path (Path): Directory where training outputs will be saved.
            Defaults to current working directory.
        weighted_loss (bool): Whether to use class-weighted loss during training.
            Helpful for imbalanced datasets. Defaults to False.
        mlflow_tracking_uri (str, optional): URI for MLFlow tracking server.
            Defaults to './mlruns'.
        from_scratch (bool): Whether to train the model from scratch (base weights)
            instead of using a pretrained checkpoint. Defaults to False.
        use_default_model_labels (bool, optional): Whether to use the full set of default model
            labels or only the labels in the provided dataset.
            If set to False, will replace the model head for finetuning and output only
            the species in the provided labels file.
            If None, automatically determined based on the labels provided.
        scheduler_config (Union[str, SchedulerConfig], optional): Learning rate scheduler
            configuration. If 'default', uses the scheduler from original training.
            Defaults to 'default'.
        split_proportions (Dict[str, int], optional): Proportions for train/val/test splits
            if no split column is provided in labels. Defaults to {'train': 3, 'val': 1, 'test': 1}.
        model_cache_dir (Path, optional): Directory where downloaded model weights will be cached.
            If None, uses the system's default cache directory. Defaults to None.
        cache_dir (Path, optional): Directory where cropped/processed images will be cached.
            Defaults to an 'image_cache' subdirectory in the system's cache directory.
        weight_download_region (str): S3 region for downloading pretrained weights.
            Options are 'us', 'eu', or 'asia'. Defaults to 'us'.
        species_in_label_order (list, optional): Optional list to specify the order of
            species labels in the model output. Defaults to None.
    """

    data_dir: Path
    labels: Union[FilePath, pd.DataFrame]
    labels_format: BboxInputFormat = BboxInputFormat.COCO
    checkpoint: Optional[FilePath] = None
    model_name: Optional[str] = ImageModelEnum.LILA_SCIENCE.value
    name: Optional[str] = "image-classification"
    max_epochs: int = 100
    lr: Optional[float] = None  # if None, will find a good learning rate
    image_size: int = 224
    batch_size: Optional[int] = 16
    accumulated_batch_size: Optional[int] = None
    early_stopping_patience: int = 3
    extra_train_augmentations: bool = False
    num_workers: int = int((os.cpu_count() or 2) // 1.5)  # default: ~2/3 of available cores
    accelerator: str = "gpu" if torch.cuda.is_available() else "cpu"
    devices: Any = "auto"
    crop_images: bool = True
    detections_threshold: float = 0.2
    checkpoint_path: Path = Path.cwd()
    weighted_loss: bool = False
    mlflow_tracking_uri: Optional[str] = "./mlruns"
    from_scratch: Optional[bool] = False
    use_default_model_labels: Optional[bool] = None
    scheduler_config: Optional[Union[str, SchedulerConfig]] = "default"
    split_proportions: Optional[Dict[str, int]] = {"train": 3, "val": 1, "test": 1}
    model_cache_dir: Optional[Path] = None
    cache_dir: Optional[Path] = Path(appdirs.user_cache_dir()) / "zamba" / "image_cache"
    weight_download_region: str = RegionEnum.us.value
    species_in_label_order: Optional[list] = None

    class Config:
        arbitrary_types_allowed = True

    _validate_model_cache_dir = validator("model_cache_dir", allow_reuse=True, always=True)(
        validate_model_cache_dir
    )

    @staticmethod
    def process_json_annotations(labels, labels_format: BboxInputFormat) -> pd.DataFrame:
        return bbox_json_to_df(labels, bbox_format=labels_format)

    @root_validator(skip_on_failure=True)
    def process_cache_dir(cls, values):
        cache_dir = values["cache_dir"]
        if not cache_dir.exists():
            cache_dir.mkdir(parents=True)
            logger.info("Cache dir created.")
        return values

    @root_validator(skip_on_failure=True)
    def validate_labels(cls, values):
        """Validate and load labels"""
        logger.info("Validating labels")

        if isinstance(values["labels"], pd.DataFrame):
            pass
        elif values["labels"].suffix == ".json":
            with open(values["labels"], "r") as f:
                values["labels"] = cls.process_json_annotations(
                    json.load(f), values["labels_format"]
                )
        else:
            values["labels"] = pd.read_csv(values["labels"])

        return values

    @root_validator(skip_on_failure=True)
    def validate_devices(cls, values):
        # per pytorch lightning docs, should be int or list of ints
        # https://lightning.ai/docs/pytorch/stable/common/trainer.html#devices
        raw_val = values["devices"]
        if "," in raw_val:
            values["devices"] = [int(v) for v in raw_val]
        elif raw_val == "auto":
            pass
        else:
            values["devices"] = int(raw_val)

        return values

    @root_validator(skip_on_failure=True)
    def validate_data_dir(cls, values):
        if not os.path.exists(values["data_dir"]):
            raise ValueError("Data dir doesn't exist.")
        return values

    @root_validator(skip_on_failure=True)
    def validate_image_files(cls, values):
        """Validate and load image files."""
        logger.info("Validating image files exist")

        exists = process_map(
            cls._validate_filepath,
            (values["data_dir"] / values["labels"].filepath.path).items(),
            chunksize=max(
                1, len(values["labels"]) // 1000
            ),  # chunks can be large; should be fast operation
            total=len(values["labels"]),
        )

        file_existence = pd.DataFrame(exists).set_index(0)
        exists = file_existence[2]

        if not exists.all():
            missing_files = file_existence[~exists]
            example_missing = [str(f) for f in missing_files.head(3)[1].values]
            logger.warning(
                f"{(~exists).sum()} files in provided labels file do not exist on disk; ignoring those files. Example: {example_missing}..."
            )

        values["labels"] = values["labels"][exists]

        return values

    @root_validator(skip_on_failure=True)
    def preprocess_labels(cls, values):
        """One hot encode and add splits."""
        logger.info("Preprocessing labels.")
        labels = values["labels"]

        # lowercase to facilitate subset checking
        labels["label"] = labels.label.str.lower()

        # one hot encoding
        labels = pd.get_dummies(labels.rename(columns={"label": "species"}), columns=["species"])

        # We validate that all the images exist prior to this, so once this assembles the set of classes,
        # we should have at least one example of each label and don't need to worry about filtering out classes
        # with missing examples.
        species_columns = labels.columns[labels.columns.str.contains("species_")]
        values["species_in_label_order"] = species_columns.to_list()

        indices = (
            labels[species_columns].idxmax(axis=1).apply(lambda x: species_columns.get_loc(x))
        )

        labels["label"] = indices

        # if no "split" column, set up train, val, and test split
        if "split" not in labels.columns:
            make_split(labels, values)

        values["labels"] = labels.reset_index()

        example_species = [
            species.replace("species_", "") for species in values["species_in_label_order"][:3]
        ]
        logger.info(
            f"Labels preprocessed. {len(values['species_in_label_order'])} species found: {example_species}..."
        )
        return values

    _validate_model_name_and_checkpoint = root_validator(allow_reuse=True, skip_on_failure=True)(
        validate_model_name_and_checkpoint
    )

    @root_validator(skip_on_failure=True)
    def validate_from_scratch(cls, values):
        from_scratch = values["from_scratch"]
        model_checkpoint = values["checkpoint"]
        if not from_scratch and model_checkpoint is None:
            raise ValueError(
                "You must specify checkpoint if you don't want to start training from scratch."
            )
        return values

    @staticmethod
    def _validate_filepath(ix_path):
        ix, path = ix_path
        path = Path(path)
        return ix, path, path.exists() and path.stat().st_size > 0
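
The devices field accepts 'auto', a comma-separated string, a plain int, or a list of ints. A standalone sketch of the normalization above (not the pydantic validator itself):

def normalize_devices(raw):
    # Comma-separated strings become lists of ints; 'auto', ints, and lists
    # pass through; anything else is coerced to a single int.
    if isinstance(raw, str) and "," in raw:
        return [int(v) for v in raw.split(",")]
    if raw == "auto" or isinstance(raw, (int, list)):
        return raw
    return int(raw)

assert normalize_devices("0,1") == [0, 1]
assert normalize_devices("auto") == "auto"
assert normalize_devices("2") == 2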

preprocess_labels(values)

One-hot encode and add splits.

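A small sketch of the one-hot step in preprocess_labels (see the class source above): 'label' is renamed to 'species' and expanded into species_* indicator columns, whose order becomes species_in_label_order.

import pandas as pd

# Hypothetical two-row labels frame.
labels = pd.DataFrame({"filepath": ["a.jpg", "b.jpg"], "label": ["cat", "dog"]})
dummies = pd.get_dummies(labels.rename(columns={"label": "species"}), columns=["species"])
print(dummies.columns.tolist())  # ['filepath', 'species_cat', 'species_dog']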

validate_image_files(values)

Validate and load image files.

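The per-file check (the _validate_filepath helper in the class source above) counts a file as present only if it exists and is non-empty. A standalone equivalent:

from pathlib import Path

def file_ok(path):
    # Present means: exists on disk and has non-zero size.
    p = Path(path)
    return p.exists() and p.stat().st_size > 0

print(file_ok("day1/img_001.jpg"))  # hypothetical path -> False unless it exists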

validate_labels(values)

Validate and load labels.

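
A sketch of the CSV branch of this validator (see the class source above), using an in-memory CSV so it runs standalone:

import io

import pandas as pd

# In-memory stand-in for a labels CSV on disk.
csv_text = "filepath,label\nday1/img_001.jpg,elephant\nday1/img_002.jpg,blank\n"
labels = pd.read_csv(io.StringIO(csv_text))
assert {"filepath", "label"}.issubset(labels.columns)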