FLAN: feature-wise latent additive neural models for biological applications

Abstract

Motivation: Interpretability has become a necessary feature of machine learning models deployed in critical scenarios, e.g. the legal system or healthcare. In these settings, algorithmic decisions may have (potentially negative) long-lasting effects on the end-user affected by the decision. While deep learning models achieve impressive results, they often function as black boxes. Inspired by linear models, we propose a novel class of structurally constrained deep neural networks, which we call FLAN (Feature-wise Latent Additive Networks). Crucially, FLANs process each input feature separately, computing for each of them a representation in a common latent space. These feature-wise latent representations are then simply summed, and the aggregated representation is used for the prediction. The feature-wise representations allow a user to estimate the effect of each individual feature independently of the others, similarly to the way linear models are interpreted.

Results: We demonstrate FLAN on a series of benchmark datasets in different biological domains. Our experiments show that FLAN achieves good performance even on complex datasets (e.g. TCR-epitope binding prediction), despite the structural constraint we imposed. In turn, this constraint enables us to interpret FLAN by deciphering its decision process, as well as to obtain biological insights (e.g. by identifying the marker genes of different cell populations). In supplementary experiments, we show similar performance on non-biological datasets.

Code and data availability: Code and example data are available at https://github.com/phineasng/flan_bio.
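The feature-wise additive structure described above can be sketched in a few lines of NumPy. This is a minimal illustration, not the authors' implementation: the network sizes, activations, and random weights below are all hypothetical, and only the structural constraint (encode each feature separately, sum the latent codes, predict from the sum) follows the paper.

```python
import numpy as np

rng = np.random.default_rng(0)

def make_feature_net(latent_dim=8, hidden=16):
    # One small per-feature network mapping a scalar feature to a latent vector.
    W1 = rng.normal(size=(1, hidden)); b1 = np.zeros(hidden)
    W2 = rng.normal(size=(hidden, latent_dim)); b2 = np.zeros(latent_dim)
    def f(x):
        h = np.tanh(x[:, None] @ W1 + b1)  # (batch, hidden)
        return h @ W2 + b2                 # (batch, latent_dim)
    return f

def flan_forward(X, feature_nets, W_out, b_out):
    # Each feature is encoded independently; the latent codes are simply
    # summed, and the aggregated representation is used for the prediction.
    Z = sum(f(X[:, j]) for j, f in enumerate(feature_nets))
    return Z @ W_out + b_out  # class logits

n_features, latent_dim, n_classes = 4, 8, 3
nets = [make_feature_net(latent_dim) for _ in range(n_features)]
W_out = rng.normal(size=(latent_dim, n_classes)); b_out = np.zeros(n_classes)
X = rng.normal(size=(5, n_features))
logits = flan_forward(X, nets, W_out, b_out)  # shape (5, 3)
```

Because each feature contributes an independent additive term to Z, one can inspect a single feature's latent code in isolation, which is what enables the linear-model-style interpretation.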

Additional results: biological datasets

Additional figures for the single cell classification task
Fig. 7. FLAN importance scores for the cellular clustering task. Shown are the importance scores averaged over the training set for 12 marker genes and the 6 most expressed genes over the whole dataset. Weights are normalized for each gene.

Additional figures for the TCR-epitope binding task
Fig. 15. FLAN's amino acid preference for the TCR and the epitope sequences. Each bar represents the frequency with which the amino acid is ranked among the top 10. The majority of the top 10 amino acids come from the epitope.

Additional figures for the image classification task

Marker genes for the single cell classification task
Table 7 displays the genes (referred to as marker genes) used by Zheng et al.
[31] to cluster the cells into immune subpopulations. Since many marker genes characterize immune populations, Table 7 does not represent the ground truth and is only used to identify the immune subpopulations in this task. Table 8 shows the mean expression of the marker genes for the 7 clusters. The mean gene expressions and FLAN's gene importances are highly correlated. Notably, CD3D is expressed 22 times less in the CD4+ cluster than in the CD8+ cluster, and 7 times less in the Naive CD4+ cluster than in the Naive CD8+ cluster. This might explain why FLAN fails to assign CD3D a high importance score in the CD4+ and Naive CD4+ clusters: FLAN only detects the biological signals present in our data, and it does not have access to prior biological knowledge. Similarly, the mean expression of ID3 is less than 0.05% in the Naive CD4+ cluster, so FLAN might not consider ID3 very important due to its weak signal.

Additional results for the TCR-epitope binding task

Sparse FLAN models
To enforce sparsity on the feature importances, we trained FLAN with additional penalties, namely l1 and l2 penalties on the network's parameters and a custom penalty on the sum of the feature-wise latent norms. The sparsity of the model trained with the latter penalty is presented in Figure 18. However, we observed a significant drop in performance: on the TCR split, the ROC-AUC score dropped to 0.80 and the balanced accuracy dropped to 0.73. Our analysis showed that the sparser the model becomes, the worse it performs. This trade-off between model accuracy and sparsity indicates that FLAN requires as much information (e.g. amino acids) as possible to predict sequence binding.
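A penalty on the sum of the feature-wise latent norms could look like the sketch below. The exact form used in the experiments (reduction over the batch, weighting) is not specified here, so the batch-mean and the `lam` coefficient are assumptions for illustration.

```python
import numpy as np

def feature_norm_penalty(Z_per_feature, lam=0.01):
    # Z_per_feature: array of shape (n_features, batch, latent_dim) holding
    # the latent code each feature contributes for each sample.
    # Penalize the sum of per-feature L2 norms, averaged over the batch
    # (assumed reduction): features that are not needed are pushed
    # toward the zero vector, zeroing their importance.
    norms = np.linalg.norm(Z_per_feature, axis=-1)  # (n_features, batch)
    return lam * norms.sum(axis=0).mean()
```

The penalty is added to the task loss during training; increasing `lam` trades accuracy for sparsity, matching the trend reported above.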

FLAN with k-mers
Similarly to learning features from non-overlapping patches in image-based tasks, we can use k-mers of amino acids to learn features that summarize information from a neighborhood of size k. In this setting, we use a convolutional network with kernel size and stride equal to 3 to learn the features. Figure 19 displays a typical example of a binding pair. We observe more sparsity when using 3-mers, while the performance remains comparable to TITAN [33]. More specifically, the ROC-AUC score is 0.86 and 0.55, and the balanced accuracy is 0.78 and 0.52, on the TCR and the strict split respectively.
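A convolution with kernel size and stride equal to k sees exactly the non-overlapping k-mers of the input sequence; the helper below makes this windowing explicit (the example CDR3-like sequence is illustrative only).

```python
def kmers(seq, k=3):
    # Split a sequence into non-overlapping k-mers: the same windows a
    # convolution with kernel size and stride k operates on. A trailing
    # fragment shorter than k is dropped here; padding would keep it.
    return [seq[i:i + k] for i in range(0, len(seq) - k + 1, k)]

kmers("CASSLGQAYEQYF")  # -> ['CAS', 'SLG', 'QAY', 'EQY']
```

With k=3 each latent contribution summarizes a triplet of amino acids rather than a single residue, which is why the resulting importance maps are sparser.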

Example-based interpretation for the image classification task
In this section, we discuss the third modality for interpreting FLAN, i.e. interpretation by examples. First, we separate the samples into 7 groups based on FLAN's predictions and perform K-Medoids clustering with 3 clusters for each group. This yields 3 prototypes per group, which can be used to identify the group's morphological characteristics. Figure 20 shows the results for the 7 groups. The rotation of the lesion and the contrast/brightness do not seem to contribute much to the classification, while the color and the shape of the lesion may be important. For instance, the prototypes in the Melanoma, Melanocytic nevi and Vascular lesions classes are darker and have better-defined shapes compared with the rest. Another way to interpret FLAN with examples is to look for the nearest and the farthest neighbor in the latent space Z. Figure 21 displays the results for a sample predicted as melanocytic nevi. Its nearest neighbor is also classified as melanocytic nevi, and they share some morphological characteristics, such as a round shape and a dark pink color. Its farthest neighbor has an irregular shape and a darker color.
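The prototype-extraction step above can be sketched as follows. For brevity this sketch picks a single medoid per predicted group (the sample minimizing the summed distance to its group), whereas the actual experiments run K-Medoids with 3 clusters per group; the function names are hypothetical.

```python
import numpy as np

def group_medoid(Z):
    # Index of the medoid: the sample with the smallest summed
    # Euclidean distance to all other samples in the group.
    D = np.linalg.norm(Z[:, None, :] - Z[None, :, :], axis=-1)
    return int(D.sum(axis=1).argmin())

def prototypes(Z, preds):
    # One medoid prototype per predicted class (simplified to a single
    # prototype; the experiments use K-Medoids with 3 clusters per group).
    protos = {}
    for c in np.unique(preds):
        idx = np.where(preds == c)[0]
        protos[int(c)] = int(idx[group_medoid(Z[idx])])
    return protos
```

Because the medoid is an actual training sample (unlike a centroid), it can be displayed directly as a representative image for its group.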

Additional results: non-biological datasets

Tabular Datasets
For benchmarking on tabular datasets, we follow Agarwal et al.
[23] and measure the performance of different models in terms of Area Under the Curve (AUC). Table 9 shows that these datasets are easy enough that a simple logistic regression model performs well. The results on the adult and mammo datasets suggest that linearity is a good inductive bias for these tasks, since logistic regression consistently outperforms all the other (non-linear) models. On the other hand, on the heart dataset and, to a lesser degree, on the COMPAS dataset, it seems beneficial to include non-linearities and interactions. In particular, FLAN closely replicates the performance of more traditional feedforward networks (MLPs). This might suggest that FLANs are similar to MLPs in terms of approximation capabilities.

Text Datasets
For text datasets, we used

Image Datasets
FLAN's results on the MNIST dataset (Table 11) are comparable to established methods. Moreover, linear models (results not reported) do not achieve more than 94% test accuracy, providing further evidence of the ability of FLANs to implement interactions without explicitly modeling them.
We further tested our model on the more difficult fine-grained image classification dataset CUB-200-2011. FLANs do not achieve the same accuracy as other models. This might be explained by the fact that the models reported are pretrained on ImageNet [72] and further fine-tuned on this dataset, whereas in our experiments our top-performing models use only some layers of a pretrained ResNeXt [73] as part of the patch feature functions. We hypothesize that the inferior performance of FLANs is due to the fact that our model has to essentially learn the interactions from scratch, and CUB-200-2011 might be too small a dataset to effectively learn this. Despite the lower accuracy, and given the relatively small size of the dataset (11.7k images split across 200 classes), we see our results as promising and a good basis for future investigations in large-scale image recognition tasks that require interpretability.

COMPAS
We use the COMPAS dataset as an introductory example to show how to interpret FLANs. We can study the approximate effect of single features separately by applying the prediction network to the feature latent representation. In this case study, this is even easier since all the features are binarized. Figure 22 shows how the predicted risk changes if we switch a feature from 0 to 1. The results suggest that the risk is increased for criminals that have a high number of priors, are younger than 25, are Afro-American, or have already re-offended in the past two years. Interestingly, the risk seems particularly decreased for criminals above the age of 45. To further validate these findings, we analyze the feature importances provided by our model. For each sample, we compute the importances as explained in Section 2.3.2 and then average them over the training set. Figure 23 confirms that, across the training set, the features mentioned above are the most discriminative ones. Previous analyses performed with interpretable models [13,23] reached similar conclusions.

Hyperparameters
The chosen hyperparameters for each experiment can be retrieved from the corresponding config.json file provided in the accompanying code. Details about the architectures used are also provided in the accompanying code. For the single cell task, we used a latent space of 24 dimensions and trained the model for 100 epochs with batch size 64. For the TCR-epitope binding task, we used 32-dim embedding vectors to encode the amino acids and then mapped them into a 128-dim latent space. We then performed 100 runs with batch size 64 (with the best set of hyperparameters). For the TCR-epitope binding task with triplets mentioned in B.5.2, we used 32-dim embedding vectors and a convolutional NN with kernel size and stride equal to 3. We also tried different paddings to find the best set of triplets. We again trained for 100 runs with batch size 64 (with the best set of hyperparameters).
For the classification of the MedMNIST dataset, we used patch size and stride equal to 4, 400 filters, and a 256-dim latent space. We then trained the model for 100 epochs with batch size 128 (with the best set of hyperparameters) and the scheduler proposed by Loshchilov and Hutter [84]. For the text, tabular and image datasets mentioned in C, we performed 50, 10 and 5 runs with the best set of hyperparameters, respectively.
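The scheduler of Loshchilov and Hutter [84] anneals the learning rate along a cosine curve. A minimal sketch of the schedule (without warm restarts; the `lr_max`/`lr_min` values are placeholders, not the ones used in the experiments):

```python
import math

def cosine_annealing_lr(step, total_steps, lr_max=1e-3, lr_min=0.0):
    # Cosine annealing: smoothly decay the learning rate from lr_max
    # (at step 0) to lr_min (at total_steps) along a half cosine wave.
    cos = math.cos(math.pi * step / total_steps)
    return lr_min + 0.5 * (lr_max - lr_min) * (1.0 + cos)
```

At each epoch the current learning rate is recomputed from the step count, so the decay needs no state beyond the step index.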