darts_segmentation.training.viz

Visualization utilities for the training module.

Augmentation module-attribute

Augmentation = typing.Literal[
    "HorizontalFlip",
    "VerticalFlip",
    "RandomRotate90",
    "D4",
    "Blur",
    "RandomBrightnessContrast",
    "MultiplicativeNoise",
    "Posterize",
]

logger module-attribute

logger = logging.getLogger(
    __name__.replace("darts_", "darts.")
)

get_augmentation

get_augmentation(
    augment: list[
        darts_segmentation.training.augmentations.Augmentation
    ]
    | None,
    always_apply: bool = False,
) -> albumentations.Compose | None

Get augmentations for segmentation tasks.

Parameters:

  • augment (list[darts_segmentation.training.augmentations.Augmentation] | None) –

    List of augmentations to apply. If None or empty, no augmentations are applied. If not empty, augmentations are applied in the order they are listed. Available augmentations:

    - D4 (Combination of HorizontalFlip, VerticalFlip, and RandomRotate90)
    - Blur
    - RandomBrightnessContrast
    - MultiplicativeNoise
    - Posterize (quantization to reduce number of bits per channel)

  • always_apply (bool, default: False ) –

    If True, augmentations are always applied. This is useful for visualization/testing augmentations. Default is False.

Raises:

  • ValueError

    If an unknown augmentation is provided.

Returns:

  • albumentations.Compose | None

    A.Compose | None: A Compose object containing the augmentations. If no augmentations are provided, returns None.

Source code in darts-segmentation/src/darts_segmentation/training/augmentations.py
def get_augmentation(augment: list[Augmentation] | None, always_apply: bool = False) -> "A.Compose | None":  # noqa: C901
    """Get augmentations for segmentation tasks.

    Args:
        augment (list[Augmentation] | None): List of augmentations to apply.
            If None or empty, no augmentations are applied.
            If not empty, augmentations are applied in the order they are listed.
            Available augmentations:
                - D4 (Combination of HorizontalFlip, VerticalFlip, and RandomRotate90)
                - Blur
                - RandomBrightnessContrast
                - MultiplicativeNoise
                - Posterize (quantization to reduce number of bits per channel)
        always_apply (bool): If True, augmentations are always applied.
            This is useful for visualization/testing augmentations.
            Default is False.

    Raises:
        ValueError: If an unknown augmentation is provided.

    Returns:
        A.Compose | None: A Compose object containing the augmentations.
            If no augmentations are provided, returns None.

    """
    import albumentations as A  # noqa: N812

    if not isinstance(augment, list) or len(augment) == 0:
        return None

    # Replace HorizontalFlip, VerticalFlip, RandomRotate90 with D4
    if "HorizontalFlip" in augment and "VerticalFlip" in augment and "RandomRotate90" in augment:
        augment = [aug for aug in augment if aug not in ("HorizontalFlip", "VerticalFlip", "RandomRotate90")]
        augment.insert(0, "D4")

    transforms = []
    for aug in augment:
        match aug:
            case "D4":
                transforms.append(A.D4())
            case "HorizontalFlip":
                transforms.append(A.HorizontalFlip(p=1.0 if always_apply else 0.5))
            case "VerticalFlip":
                transforms.append(A.VerticalFlip(p=1.0 if always_apply else 0.5))
            case "RandomRotate90":
                transforms.append(A.RandomRotate90())
            case "Blur":
                transforms.append(A.Blur(p=1.0 if always_apply else 0.5))
            case "RandomBrightnessContrast":
                transforms.append(A.RandomBrightnessContrast(p=1.0 if always_apply else 0.5))
            case "MultiplicativeNoise":
                transforms.append(
                    A.MultiplicativeNoise(per_channel=True, elementwise=True, p=1.0 if always_apply else 0.5)
                )
            case "Posterize":
                # First convert to uint8, then apply posterization, then convert back to float32
                # * Note: This only works for float32 images.
                transforms += [
                    A.FromFloat(dtype="uint8"),
                    A.Posterize(num_bits=6, p=1.0),
                    A.ToFloat(),
                ]
            case _:
                raise ValueError(f"Unknown augmentation: {aug}")
    return A.Compose(transforms)
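
Example usage (a minimal sketch; the random image and mask below are illustrative placeholders, not part of the module):

import numpy as np

from darts_segmentation.training.augmentations import get_augmentation

# Build the pipeline; HorizontalFlip + VerticalFlip + RandomRotate90 are merged into a single D4.
transform = get_augmentation(["HorizontalFlip", "VerticalFlip", "RandomRotate90", "Blur"])

# Albumentations expects channel-last numpy arrays; values here are random placeholders.
image = np.random.rand(256, 256, 3).astype("float32")
mask = np.random.randint(0, 2, (256, 256), dtype="uint8")

augmented = transform(image=image, mask=mask)
aug_image, aug_mask = augmented["image"], augmented["mask"]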

plot_augmentations

plot_augmentations(
    x: torch.Tensor,
    augmentations: list[
        darts_segmentation.training.augmentations.Augmentation
    ],
    band_names: list[str],
) -> tuple[
    ultraplot.Figure, ultraplot.gridspec.SubplotGrid
]

Plot augmentations applied to a sample image.

Parameters:

  • x (torch.Tensor) –

    Input tensor [N, C, H, W] (float).

  • augmentations (list[darts_segmentation.training.augmentations.Augmentation]) –

    List of augmentations to apply.

  • band_names (list[str]) –

    List of band names corresponding to the channels in x.

Returns:

  • ultraplot.Figure

    ultraplot.Figure: The figure object containing the plots.

  • ultraplot.gridspec.SubplotGrid

    ultraplot.gridspec.SubplotGrid: The axes of the plot.

Source code in darts-segmentation/src/darts_segmentation/training/viz.py
def plot_augmentations(
    x: torch.Tensor, augmentations: list[Augmentation], band_names: list[str]
) -> tuple[uplt.Figure, uplt.gridspec.SubplotGrid]:
    """Plot augmentations applied to a sample image.

    Args:
        x (torch.Tensor): Input tensor [N, C, H, W] (float).
        augmentations (list[Augmentation]): List of augmentations to apply.
        band_names (list[str]): List of band names corresponding to the channels in x.

    Returns:
        ultraplot.Figure: The figure object containing the plots.
        ultraplot.gridspec.SubplotGrid: The axes of the plot.

    """
    compose = get_augmentation(augmentations)
    augmentations: dict[str, A.BasicTransform] = {aug: get_augmentation([aug], True) for aug in augmentations}

    rgb_idx = [band_names.index(band) for band in ["red", "green", "blue"]]

    nrows = 1 + len(augmentations) + 4
    ncols = x.shape[0]
    fig, axs = uplt.subplots(ncols=ncols, nrows=nrows, figsize=(ncols * 5, nrows * 5))
    for i in range(ncols):
        img = x[i, rgb_idx].permute(1, 2, 0).cpu().numpy()
        axs[0, i].imshow(img, vmin=0, vmax=0.1)
        axs[0, i].set_title("Original Image")
        for j, (aug_name, aug_fn) in enumerate(augmentations.items()):
            augmented = aug_fn(image=img)
            aug_img = augmented["image"]
            axs[j + 1, i].imshow(aug_img, vmin=0, vmax=0.1)
            axs[j + 1, i].set_title(f"Augmented: {aug_name}")

        # Apply full compose
        for j in range(4):
            augmented = compose(image=img)
            aug_img = augmented["image"]
            axs[j + 1 + len(augmentations), i].imshow(aug_img, vmin=0, vmax=0.1)
            axs[j + 1 + len(augmentations), i].set_title(f"Compose Augmentation {j + 1}")
    return fig, axs
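
A hedged usage sketch (the random batch and band names are illustrative; in practice x would come from a dataset sample):

import torch

from darts_segmentation.training.viz import plot_augmentations

# Hypothetical batch of two 3-band patches in [0, 1], shape [N, C, H, W]
x = torch.rand(2, 3, 256, 256)
band_names = ["red", "green", "blue"]  # must contain "red", "green" and "blue"

fig, axs = plot_augmentations(x, ["Blur", "RandomBrightnessContrast"], band_names)
fig.savefig("augmentations.png")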

plot_sample

plot_sample(
    x: torch.Tensor,
    y: torch.Tensor,
    y_pred: torch.Tensor,
    band_names: list[str],
) -> tuple[
    matplotlib.pyplot.Figure,
    dict[str, matplotlib.pyplot.Axes],
]

Plot a single sample with the input, the ground truth and the prediction.

This function makes a few assumptions about the input:

- The input is expected to be normalized to 0-1.
- The prediction is expected to be converted from logits to a prediction.
- The target is expected to be an int or long tensor with values of 0 (negative class), 1 (positive class), and 2 (invalid pixels).

Parameters:

  • x (torch.Tensor) –

    The input tensor [C, H, W] (float).

  • y (torch.Tensor) –

    The ground truth tensor [H, W] (int).

  • y_pred (torch.Tensor) –

    The prediction tensor [H, W] (float).

  • band_names (list[str]) –

    The combinations of the input bands.

Returns:

  • tuple[matplotlib.pyplot.Figure, dict[str, matplotlib.pyplot.Axes]]

    tuple[Figure, dict[str, Axes]]: The figure and the axes of the plot.

Source code in darts-segmentation/src/darts_segmentation/training/viz.py
def plot_sample(
    x: torch.Tensor, y: torch.Tensor, y_pred: torch.Tensor, band_names: list[str]
) -> tuple[plt.Figure, dict[str, plt.Axes]]:
    """Plot a single sample with the input, the ground truth and the prediction.

    This function makes a few assumptions about the input:
    - The input is expected to be normalized to 0-1.
    - The prediction is expected to be converted from logits to prediction.
    - The target is expected to be an int or long tensor with values of:
        0 (negative class)
        1 (positive class) and
        2 (invalid pixels).

    Args:
        x (torch.Tensor): The input tensor [C, H, W] (float).
        y (torch.Tensor): The ground truth tensor [H, W] (int).
        y_pred (torch.Tensor): The prediction tensor [H, W] (float).
        band_names (list[str]): The combinations of the input bands.

    Returns:
        tuple[Figure, dict[str, Axes]]: The figure and the axes of the plot.

    """
    x = x.cpu()
    y = y.cpu()
    y_pred = y_pred.detach().cpu()

    # Make y class 2 invalids (replace 2 with nan)
    x = x.where(y != 2, torch.nan)
    y_pred = y_pred.where(y != 2, torch.nan)
    y = y.where(y != 2, torch.nan)

    # pred == 0, y == 0 -> 0 (true negative)
    # pred == 1, y == 0 -> 1 (false positive)
    # pred == 0, y == 1 -> 2 (false negative)
    # pred == 1, y == 1 -> 3 (true positive)
    classification_labels = (y_pred > 0.5).int() + y * 2

    # Calculate accuracy, f1 and iou before masking out the true negatives for plotting
    true_positive = (classification_labels == 3).sum()
    false_positive = (classification_labels == 1).sum()
    false_negative = (classification_labels == 2).sum()
    true_negative = (classification_labels == 0).sum()
    acc = (true_positive + true_negative) / (true_positive + true_negative + false_positive + false_negative)
    f1 = 2 * true_positive / (2 * true_positive + false_positive + false_negative)
    iou = true_positive / (true_positive + false_positive + false_negative)

    # Hide true negatives in the overlay (replace 0 with nan)
    classification_labels = classification_labels.where(classification_labels != 0, torch.nan)

    cmap = mcolors.ListedColormap(["#cd43b2", "#3e0f2f", "#6cd875"])
    fig, axs = plt.subplot_mosaic(
        # [["rgb", "rgb", "ndvi", "tcvis", "stats"], ["rgb", "rgb", "pred", "slope", "elev"]],
        [["rgb", "rgb", "pred", "tcvis"], ["rgb", "rgb", "ndvi", "slope"], ["none", "stats", "stats", "stats"]],
        # layout="constrained",
        figsize=(11, 8),
    )

    # Disable none plot
    axs["none"].axis("off")

    # RGB Plot
    ax_rgb = axs["rgb"]
    # disable axis
    ax_rgb.axis("off")
    is_rgb = "red" in band_names and "green" in band_names and "blue" in band_names
    if is_rgb:
        red_band = band_names.index("red")
        green_band = band_names.index("green")
        blue_band = band_names.index("blue")
        rgb = x[[red_band, green_band, blue_band]].transpose(0, 2).transpose(0, 1)
        ax_rgb.imshow(rgb ** (1 / 1.4))
        ax_rgb.set_title(f"Acc: {acc:.1%} F1: {f1:.1%} IoU: {iou:.1%}")
    else:
        # Plot empty with message that RGB is not provided
        ax_rgb.set_title("No RGB values are provided!")
    ax_rgb.imshow(classification_labels, alpha=0.6, cmap=cmap, vmin=1, vmax=3)
    # Add a legend
    patches = [
        mpatches.Patch(color="#6cd875", label="True Positive"),
        mpatches.Patch(color="#3e0f2f", label="False Negative"),
        mpatches.Patch(color="#cd43b2", label="False Positive"),
    ]
    ax_rgb.legend(handles=patches, loc="upper left")

    # NDVI Plot
    ax_ndvi = axs["ndvi"]
    ax_ndvi.axis("off")
    is_ndvi = "ndvi" in band_names
    if is_ndvi:
        ndvi_band = band_names.index("ndvi")
        ndvi = x[ndvi_band]
        ax_ndvi.imshow(ndvi, vmin=0, vmax=1, cmap="RdYlGn")
        ax_ndvi.set_title("NDVI")
    else:
        # Plot empty with message that NDVI is not provided
        ax_ndvi.set_title("No NDVI values are provided!")

    # TCVIS Plot
    ax_tcv = axs["tcvis"]
    ax_tcv.axis("off")
    is_tcvis = "tc_brightness" in band_names and "tc_greenness" in band_names and "tc_wetness" in band_names
    if is_tcvis:
        tcb_band = band_names.index("tc_brightness")
        tcg_band = band_names.index("tc_greenness")
        tcw_band = band_names.index("tc_wetness")
        tcvis = x[[tcb_band, tcg_band, tcw_band]].transpose(0, 2).transpose(0, 1)
        ax_tcv.imshow(tcvis)
        ax_tcv.set_title("TCVIS")
    else:
        ax_tcv.set_title("No TCVIS values are provided!")

    # Statistics Plot
    ax_stat = axs["stats"]
    if (y == 1).sum() > 0:
        n_bands = x.shape[0]
        n_pixel = x.shape[1] * x.shape[2]
        x_flat = x.flatten().cpu()
        y_flat = y.flatten().repeat(n_bands).cpu()
        bands = list(itertools.chain.from_iterable([band_names[i]] * n_pixel for i in range(n_bands)))
        plot_data = pd.DataFrame({"x": x_flat, "y": y_flat, "band": bands})
        if len(plot_data) > 50000:
            plot_data = plot_data.sample(50000)
        plot_data = plot_data.sort_values("band")
        sns.violinplot(
            x="x",
            y="band",
            hue="y",
            data=plot_data,
            split=True,
            inner="quart",
            fill=False,
            palette={1: "g", 0: ".35"},
            density_norm="width",
            ax=ax_stat,
        )
        ax_stat.set_title("Band Statistics")
    else:
        ax_stat.set_title("No positive labels in this sample!")
        ax_stat.axis("off")

    # Prediction Plot
    ax_mask = axs["pred"]
    ax_mask.imshow(y_pred, vmin=0, vmax=1)
    ax_mask.axis("off")
    ax_mask.set_title("Model Output")

    # Slope Plot
    ax_slope = axs["slope"]
    ax_slope.axis("off")
    is_slope = "slope" in band_names
    if is_slope:
        slope_band = band_names.index("slope")
        slope = x[slope_band]
        ax_slope.imshow(slope, cmap="cividis")
        # Add TPI as contour lines
        is_rel_elev = "relative_elevation" in band_names
        if is_rel_elev:
            rel_elev_band = band_names.index("relative_elevation")
            rel_elev = x[rel_elev_band]
            cs = ax_slope.contour(rel_elev, [0], colors="red", linewidths=0.3, alpha=0.6)
            ax_slope.clabel(cs, inline=True, fontsize=5, fmt="%.1f")

        ax_slope.set_title("Slope")
    else:
        # Plot empty with message that slope is not provided
        ax_slope.set_title("No Slope values are provided!")

    # Relative Elevation Plot
    # rel_elev_band = band_names.index("relative_elevation")
    # rel_elev = x[rel_elev_band]
    # ax_rel_elev = axs["elev"]
    # ax_rel_elev.imshow(rel_elev, cmap="cividis")
    # ax_rel_elev.axis("off")
    # ax_rel_elev.set_title("Relative Elevation")

    return fig, axs
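
A minimal calling sketch with random tensors standing in for a real sample and model output (purely illustrative):

import torch

from darts_segmentation.training.viz import plot_sample

band_names = ["red", "green", "blue", "ndvi"]  # illustrative band layout
x = torch.rand(len(band_names), 256, 256)      # input normalized to 0-1, [C, H, W]
y = torch.randint(0, 3, (256, 256))            # 0 = negative, 1 = positive, 2 = invalid
y_pred = torch.rand(256, 256)                  # probabilities (already converted from logits)

fig, axs = plot_sample(x, y, y_pred, band_names)
fig.savefig("sample.png")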

plot_training_data_distribution

plot_training_data_distribution(
    train_metadata: geopandas.GeoDataFrame,
    val_metadata: geopandas.GeoDataFrame | None,
    test_metadata: geopandas.GeoDataFrame | None,
    name: str,
) -> tuple[
    matplotlib.pyplot.Figure,
    dict[str, matplotlib.pyplot.Axes],
]

Plot the distribution of training data by region on a polar projection.

Parameters:

  • train_metadata (geopandas.GeoDataFrame) –

    GeoDataFrame containing training metadata.

  • val_metadata (geopandas.GeoDataFrame | None) –

    GeoDataFrame containing validation metadata.

  • test_metadata (geopandas.GeoDataFrame | None) –

    GeoDataFrame containing test metadata.

  • name (str) –

    Name of the dataset or experiment for the plot title.

Returns:

  • tuple[matplotlib.pyplot.Figure, dict[str, matplotlib.pyplot.Axes]]

    tuple[plt.Figure, dict[str, plt.Axes]]: The figure and axes of the plot.

Source code in darts-segmentation/src/darts_segmentation/training/viz.py
def plot_training_data_distribution(
    train_metadata: gpd.GeoDataFrame,
    val_metadata: gpd.GeoDataFrame | None,
    test_metadata: gpd.GeoDataFrame | None,
    name: str,
) -> tuple[plt.Figure, dict[str, plt.Axes]]:
    """Plot the distribution of training data by region on a polar projection.

    Args:
        train_metadata (gpd.GeoDataFrame): GeoDataFrame containing training metadata.
        val_metadata (gpd.GeoDataFrame | None): GeoDataFrame containing validation metadata.
        test_metadata (gpd.GeoDataFrame | None): GeoDataFrame containing test metadata.
        name (str): Name of the dataset or experiment for the plot title.

    Returns:
        tuple[plt.Figure, dict[str, plt.Axes]]: The figure and axes of the plot.

    """
    # Aggregate by sample_id to get counts of not-empty tiles for train and test
    # Get centroids of the aggregated geometries
    train_metadata["not-empty"] = ~train_metadata["empty"]
    train_sample_data = train_metadata[["sample_id", "not-empty", "geometry"]].dissolve(by="sample_id", aggfunc="sum")
    train_centroids = train_sample_data.geometry.centroid
    if val_metadata is not None:
        val_metadata["not-empty"] = ~val_metadata["empty"]
        val_sample_data = val_metadata[["sample_id", "not-empty", "geometry"]].dissolve(by="sample_id", aggfunc="sum")
        val_centroids = val_sample_data.geometry.centroid
    if test_metadata is not None:
        test_metadata["not-empty"] = ~test_metadata["empty"]
        test_sample_data = test_metadata[["sample_id", "not-empty", "geometry"]].dissolve(by="sample_id", aggfunc="sum")
        test_centroids = test_sample_data.geometry.centroid

    # Create figure with NorthPolarStereo projection
    fig, axs = plt.subplot_mosaic(
        [["map", "map", "map", "train-dist"], ["map", "map", "map", "val-dist"], ["map", "map", "map", "test-dist"]],
        layout="constrained",
        figsize=(12, 8),
        per_subplot_kw={"map": {"projection": ccrs.NorthPolarStereo()}},
    )

    # Set the extent to limit to 55°N latitude (circular boundary)
    axs["map"].set_extent([-180, 180, 55, 90], ccrs.PlateCarree())

    # Add map features
    axs["map"].add_feature(cfeature.LAND, facecolor="lightgray", alpha=0.3)
    axs["map"].add_feature(cfeature.OCEAN, facecolor="lightblue", alpha=0.3)
    axs["map"].add_feature(cfeature.COASTLINE, linewidth=0.5)
    axs["map"].add_feature(cfeature.BORDERS, linewidth=0.3, linestyle=":")

    # Add gridlines
    axs["map"].gridlines(draw_labels=True, linewidth=0.5, alpha=0.5, linestyle="--")

    # Determine common vmax for consistent color scaling
    vmax = max(
        train_sample_data["not-empty"].max(),
        (val_sample_data["not-empty"].max() if val_metadata is not None else 0),
        (test_sample_data["not-empty"].max() if test_metadata is not None else 0),
    )

    # Plot the training regions with circles
    train_scatter = axs["map"].scatter(
        train_centroids.x,
        train_centroids.y,
        c=train_sample_data["not-empty"],
        cmap="YlGnBu",
        s=120,
        alpha=0.7,
        transform=ccrs.PlateCarree(),
        edgecolors="black",
        linewidths=0.5,
        vmin=0,
        vmax=vmax,
        marker="o",  # Circle for training data
        label="Training",
    )

    # Plot the validation regions with stars
    if val_metadata is not None:
        axs["map"].scatter(
            val_centroids.x,
            val_centroids.y,
            c=val_sample_data["not-empty"],
            cmap="YlGnBu",
            s=80,
            alpha=0.7,
            transform=ccrs.PlateCarree(),
            edgecolors="black",
            linewidths=0.5,
            vmin=0,
            vmax=vmax,
            marker="*",  # Circle for training data
            label="Validation",
        )

    # Plot the test regions with triangles
    if test_metadata is not None:
        axs["map"].scatter(
            test_centroids.x,
            test_centroids.y,
            c=test_sample_data["not-empty"],
            cmap="YlGnBu",
            s=80,
            alpha=0.7,
            transform=ccrs.PlateCarree(),
            edgecolors="black",
            linewidths=0.5,
            vmin=0,
            vmax=vmax,
            marker="^",  # Triangle for test data
            label="Test",
        )

    # Add colorbar
    cbar = plt.colorbar(train_scatter, ax=axs["map"], shrink=0.6, pad=0.05)
    cbar.set_label("Number of Patches with Data", rotation=270, labelpad=20, fontsize=12)

    # Add legend for train/test split
    legend = axs["map"].legend(
        loc="lower left", frameon=True, fancybox=True, shadow=True, fontsize=11, title="Data Split"
    )
    legend.get_frame().set_alpha(0.9)

    # Create circular boundary at 55°N
    theta = np.linspace(0, 2 * np.pi, 100)
    verts = np.vstack([np.sin(theta), np.cos(theta)]).T
    circle = mpath.Path(verts * 0.5 + 0.5)
    axs["map"].set_boundary(circle, transform=axs["map"].transAxes)

    axs["map"].set_title(f"Training Data Distribution by Region ({name})", fontsize=14, fontweight="bold", pad=20)

    sns.histplot(
        train_metadata,
        y="region",
        hue="not-empty",
        multiple="stack",
        ax=axs["train-dist"],
        palette=["#7f8c8d", "#27ae60"],  # Gray for w/o RTS, Green for w/ RTS
    )
    axs["train-dist"].set_title("Training Set Distribution by Region")
    axs["train-dist"].legend(labels=["w RTS", "w/o RTS"])
    axs["train-dist"].set_ylabel("")
    axs["train-dist"].set_xlabel("Number of Patches")

    if val_metadata is not None:
        sns.histplot(
            val_metadata,
            y="region",
            hue="not-empty",
            multiple="stack",
            ax=axs["val-dist"],
            palette=["#7f8c8d", "#27ae60"],  # Gray for w/o RTS, Green for w/ RTS
        )
        axs["val-dist"].set_title("Validation Set Distribution by Region")
        axs["val-dist"].legend(labels=["w/ RTS", "w/o RTS"])
        axs["val-dist"].set_ylabel("")
        axs["val-dist"].set_xlabel("Number of Patches")

    if test_metadata is not None:
        sns.histplot(
            test_metadata,
            y="region",
            hue="not-empty",
            multiple="stack",
            ax=axs["test-dist"],
            palette=["#7f8c8d", "#27ae60"],  # Gray for w/o RTS, Green for w/ RTS
        )
        axs["test-dist"].set_title("Test Set Distribution by Region")
        axs["test-dist"].legend(labels=["w/ RTS", "w/o RTS"])
        axs["test-dist"].set_ylabel("")
        axs["test-dist"].set_xlabel("Number of Patches")

    # fig.tight_layout()
    return fig, axs
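
A hedged calling sketch, assuming train/val metadata GeoDataFrames with the "sample_id", "region", "empty", and "geometry" columns used in the source above; the parquet paths and dataset name are illustrative:

import geopandas as gpd

from darts_segmentation.training.viz import plot_training_data_distribution

# Hypothetical metadata files produced during dataset preprocessing
train_metadata = gpd.read_parquet("train_metadata.parquet")
val_metadata = gpd.read_parquet("val_metadata.parquet")

# test_metadata may be None, in which case the test scatter and histogram are skipped
fig, axs = plot_training_data_distribution(train_metadata, val_metadata, None, name="my-dataset")
fig.savefig("data_distribution.png")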