buddi.models.components.losses

Functions

classifier_loss_generator([weight, loss_fn, ...])

Classifier loss function generator.
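The page does not document how the generator behaves, so the following is only a sketch under assumptions. A "loss function generator" is taken here to be a factory that captures its arguments and returns a `(y_true, y_pred)` closure. The sketch models only `weight` and `loss_fn`, the two parameters visible in the signature. `make_classifier_loss` and `categorical_crossentropy` are hypothetical pure-Python stand-ins, not BuDDI's code.

```python
import math
from typing import Callable, Sequence

LossFn = Callable[[Sequence[float], Sequence[float]], float]


def categorical_crossentropy(y_true: Sequence[float], y_pred: Sequence[float],
                             eps: float = 1e-7) -> float:
    # -sum_i y_true[i] * log(y_pred[i]); predictions are clipped to avoid log(0)
    return -sum(t * math.log(min(max(p, eps), 1.0 - eps)) for t, p in zip(y_true, y_pred))


def make_classifier_loss(weight: float = 1.0,
                         loss_fn: LossFn = categorical_crossentropy) -> LossFn:
    # Hypothetical stand-in for classifier_loss_generator: captures weight and
    # loss_fn and returns a closure with the usual (y_true, y_pred) signature.
    def loss(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
        return weight * loss_fn(y_true, y_pred)

    return loss


label_loss = make_classifier_loss(weight=2.0)
print(label_loss([0.0, 1.0], [0.5, 0.5]))  # 2 * ln 2
```

Returning a closure lets each output branch get its own weighted loss while keeping the two-argument `(y_true, y_pred)` signature that Keras expects of a loss function.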

kl_loss(y_true, y_pred)

Kullback-Leibler (KL) divergence loss function.
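Variational autoencoders commonly use a closed-form KL term. It measures the divergence between a diagonal Gaussian posterior N(mu, diag(exp(log_var))) and a standard normal prior. The page does not say whether `kl_loss` computes exactly this, or how it unpacks the mean and log-variance from `y_pred`. The sketch below assumes the two are passed separately, and `gaussian_kl` is a hypothetical name.

```python
import math
from typing import Sequence


def gaussian_kl(mu: Sequence[float], log_var: Sequence[float]) -> float:
    # Hypothetical helper: KL(N(mu, diag(exp(log_var))) || N(0, I)),
    # summed over latent dimensions:
    #   0.5 * sum(mu^2 + exp(log_var) - log_var - 1)
    return 0.5 * sum(m * m + math.exp(lv) - lv - 1.0 for m, lv in zip(mu, log_var))


print(gaussian_kl([0.0, 0.0], [0.0, 0.0]))  # posterior equals prior → 0.0
print(gaussian_kl([1.0], [0.0]))            # unit shift of the mean → 0.5
```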

kl_loss_generator([beta, agg_fn])

KL divergence loss function generator.
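The following is a minimal sketch. It assumes `beta` scales the KL term, as in a beta-VAE, and that `agg_fn` reduces per-sample KL values over the batch, for example with a mean or a sum. `make_kl_loss` is a hypothetical name. Representing the batch as a list of `(mu, log_var)` pairs is an illustrative assumption, not BuDDI's tensor layout.

```python
import math
from statistics import mean
from typing import Callable, List, Sequence, Tuple

Batch = List[Tuple[Sequence[float], Sequence[float]]]


def gaussian_kl(mu: Sequence[float], log_var: Sequence[float]) -> float:
    # Closed-form KL to a standard normal prior, summed over latent dimensions
    return 0.5 * sum(m * m + math.exp(lv) - lv - 1.0 for m, lv in zip(mu, log_var))


def make_kl_loss(beta: float = 1.0,
                 agg_fn: Callable[[Sequence[float]], float] = mean) -> Callable[[Batch], float]:
    # Hypothetical stand-in for kl_loss_generator: per-sample KL values are
    # reduced with agg_fn, then scaled by beta.
    def loss(batch: Batch) -> float:
        return beta * agg_fn([gaussian_kl(mu, lv) for mu, lv in batch])

    return loss


batch = [([0.0], [0.0]), ([1.0], [0.0])]  # per-sample KL: 0.0 and 0.5
print(make_kl_loss(beta=2.0)(batch))       # 2 * mean(0.0, 0.5) → 0.5
```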

reconstr_loss_generator([weight, ...])

Reconstruction loss function generator.
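This sketch assumes the reconstruction term is a weighted mean squared error between an input profile and its decoded reconstruction. The page lists `MeanSquaredError` among this module's classes but does not say the generator uses it. `make_reconstr_loss` is hypothetical and models only the `weight` parameter.

```python
from typing import Callable, Sequence


def mse(x: Sequence[float], x_hat: Sequence[float]) -> float:
    # Mean of squared errors across the features of one sample
    return sum((a - b) ** 2 for a, b in zip(x, x_hat)) / len(x)


def make_reconstr_loss(weight: float = 1.0) -> Callable[[Sequence[float], Sequence[float]], float]:
    # Hypothetical stand-in for reconstr_loss_generator; only `weight` is modeled.
    def loss(x: Sequence[float], x_hat: Sequence[float]) -> float:
        return weight * mse(x, x_hat)

    return loss


print(make_reconstr_loss(weight=3.0)([1.0, 2.0], [1.0, 4.0]))  # → 6.0
```

In a VAE-style objective, raising the reconstruction weight relative to the KL weight favors reconstruction fidelity over a regular latent space.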

unsupervised_dummy_loss_fn(y_true, y_pred)

Dummy loss function for the unsupervised branch's proportion estimator.
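A dummy loss is typically a function with the usual `(y_true, y_pred)` signature that always returns zero. This lets an output with no training targets stay in a multi-output model without affecting the total loss. Here the output without targets is presumably the proportion estimate for unlabeled samples. The sketch below assumes that behavior, and `dummy_loss` and `total_loss` are hypothetical names.

```python
from typing import Iterable, Sequence


def dummy_loss(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    # Hypothetical stand-in for unsupervised_dummy_loss_fn: both arguments are
    # ignored because unlabeled samples carry no proportion targets.
    return 0.0


def total_loss(terms: Iterable[float]) -> float:
    # A multi-output model adds its per-output losses into one objective.
    return sum(terms)


print(total_loss([1.25, dummy_loss([0.2, 0.8], [0.5, 0.5])]))  # → 1.25
```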

Classes

CategoricalCrossentropy([from_logits, ...])

Computes the crossentropy loss between the labels and predictions.
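This class has the same name and one-line summary as `tf.keras.losses.CategoricalCrossentropy`, so it appears to be the Keras class imported into this module. The sketch below re-derives the arithmetic in plain Python as an illustration only. When `from_logits=True`, each row gets a softmax. Otherwise each row is renormalized to sum to 1. Probabilities are clipped away from 0, and per-sample losses are averaged over the batch. `cce` is a hypothetical helper, not the Keras implementation.

```python
import math
from typing import List, Sequence


def softmax(z: Sequence[float]) -> List[float]:
    m = max(z)  # subtract the max for numerical stability
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]


def cce(y_true: Sequence[Sequence[float]], y_pred: Sequence[Sequence[float]],
        from_logits: bool = False, eps: float = 1e-7) -> float:
    # Hypothetical helper illustrating the arithmetic; not the Keras class.
    losses = []
    for t, p in zip(y_true, y_pred):
        if from_logits:
            p = softmax(p)
        else:
            s = sum(p)
            p = [v / s for v in p]  # renormalize so each row sums to 1
        p = [min(max(v, eps), 1.0 - eps) for v in p]  # avoid log(0)
        losses.append(-sum(ti * math.log(pi) for ti, pi in zip(t, p)))
    return sum(losses) / len(losses)  # average over the batch


y_true = [[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
y_pred = [[0.05, 0.95, 0.0], [0.1, 0.8, 0.1]]
print(round(cce(y_true, y_pred), 3))  # → 1.177
```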

MeanSquaredError([reduction, name, dtype])

Computes the mean of squares of errors between labels and predictions.
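If this is the Keras `MeanSquaredError` class, as its name and summary suggest, it averages squared differences over the last axis of each sample. Under the default reduction, it then averages those per-sample values over the batch. The plain-Python sketch below shows that arithmetic, with `mse_batch` as a hypothetical helper rather than the Keras class.

```python
from typing import Sequence


def mse_batch(y_true: Sequence[Sequence[float]], y_pred: Sequence[Sequence[float]]) -> float:
    # Hypothetical helper: mean over the last axis of each sample, then mean over the batch.
    per_sample = [sum((t - p) ** 2 for t, p in zip(ts, ps)) / len(ts)
                  for ts, ps in zip(y_true, y_pred)]
    return sum(per_sample) / len(per_sample)


print(mse_batch([[0.0, 1.0], [0.0, 0.0]], [[1.0, 1.0], [1.0, 0.0]]))  # → 0.5
```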