TensorFlow Similarity Model (Part 1): what is it and how is it different from a normal classifier?
Recently, the TensorFlow team published an article about Similarity Models, which have a wide range of potential practical applications. I therefore decided to take a deeper look into the library and the concept behind it. This post records the findings and key takeaways I found interesting in the current version of the library.
Where can Similarity Model be useful?
A Similarity Model provides the ability to search for related items. Its use cases are similar to those of K-Nearest Neighbors search, but with the advantage that the search runs over embeddings produced by a DNN model instead of over raw data.
A Similarity Model also addresses the scalability issue: it uses Approximate Nearest Neighbors (ANN) search, which is sub-linear, and adding new classes does not require model retraining.
Normal Deep Neural Network (DNN) Classifier vs Similarity Model
One might ask what the difference is between a Similarity Model and a normal DNN classification model. In short, a Similarity Model can be built like a normal classification model (e.g. convolutional layers + activation layers + dense layers, etc.), but it has different training objectives (different loss functions) and handles its output differently.
Specifically, as illustrated in the diagram below, the feature extraction layers of a Similarity Model can use the same architecture as a normal DNN classifier. What differs is the training objective and the output format.
Training objectives in Similarity Model
A Similarity Model focuses on contrastive learning, which learns output embeddings such that similar samples stay close together and far from dissimilar ones. Loss functions in standard classification models focus on predicting the classes in a dataset. A Similarity Model instead uses a different set of loss functions called contrastive losses (e.g. Ranking loss, Margin loss, Hinge loss, etc.). The main goal of a contrastive loss is to minimize the distance between similar samples and maximize the distance between dissimilar samples.
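To make that goal concrete, here is a minimal sketch of a pairwise contrastive loss in plain Python. It is an illustration under assumptions, not the loss code used by TensorFlow Similarity: the function names, the Euclidean distance, and the margin of 1.0 are all choices made for this example.

```python
import math


def euclidean(a, b):
    """Euclidean distance between two equal-length vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def pairwise_contrastive_loss(emb_a, emb_b, same_class, margin=1.0):
    """Loss for a single pair of embeddings (hypothetical helper).

    Similar pairs are penalized by their squared distance, so training
    pulls them together. Dissimilar pairs are penalized only while they
    are closer than `margin`, so training pushes them apart.
    """
    d = euclidean(emb_a, emb_b)
    if same_class:
        return d ** 2
    return max(0.0, margin - d) ** 2


# Toy 2-D embeddings, made up for the example.
anchor = [0.0, 0.0]
close_point = [0.1, 0.0]
far_point = [2.0, 0.0]

loss_similar_close = pairwise_contrastive_loss(anchor, close_point, same_class=True)
loss_similar_far = pairwise_contrastive_loss(anchor, far_point, same_class=True)
loss_dissimilar_close = pairwise_contrastive_loss(anchor, close_point, same_class=False)
loss_dissimilar_far = pairwise_contrastive_loss(anchor, far_point, same_class=False)
```

In this sketch a similar pair costs more the further apart it is, while a dissimilar pair costs nothing once it is at least `margin` apart. That is the "pull together, push apart" behavior described above.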
The output layer of a normal classification model is usually a fully connected layer whose values are then normalized with a softmax or sigmoid function. The final output can be used directly for making predictions.
For a Similarity Model, however, the output of the DNN component is a dense layer whose activations are the learned metric embedding of the input. These embeddings are then consumed by an Approximate Nearest Neighbors component, which returns the final prediction. As a result, the final prediction of a Similarity Model is the set of labels of the nearest samples in the training data, which can span more than one class if the pre-defined number of nearest neighbors is greater than 1.
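The two kinds of output handling can be contrasted with a small plain-Python sketch. Everything here is assumed for illustration: the class names, logits, and embeddings are made up, and an exact brute-force search stands in for the ANN component, so this is not the library's API.

```python
import math


def softmax(logits):
    """Turn raw class scores into probabilities that sum to 1."""
    peak = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - peak) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def classifier_predict(logits, class_names):
    """Normal classifier: the most probable class is the prediction."""
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return class_names[best]


def nearest_labels(query, index, k=3):
    """Similarity-style prediction: labels of the k closest indexed embeddings.

    `index` is a list of (embedding, label) pairs. Exact search is used
    here as a stand-in for an approximate nearest neighbors index.
    """
    ranked = sorted(index, key=lambda item: math.dist(query, item[0]))
    return [label for _, label in ranked[:k]]


# Normal classifier: one logit per known class.
class_names = ["cat", "dog", "bird"]
logits = [2.0, 0.5, -1.0]
classifier_label = classifier_predict(logits, class_names)

# Similarity Model: embeddings of training samples, indexed with their labels.
index = [
    ([0.0, 0.1], "cat"),
    ([0.1, 0.0], "cat"),
    ([1.0, 1.0], "dog"),
    ([0.9, 1.1], "dog"),
    ([-1.0, 1.0], "bird"),
]
query_embedding = [0.05, 0.05]
neighbor_labels = nearest_labels(query_embedding, index, k=3)
```

With `k=3` the lookup returns two "cat" labels and one "dog" label, which shows how the result can contain more than one class once the number of neighbors is greater than 1.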
Performing Nearest Neighbor search on the extracted features from a normal classifier?
Since the architecture for feature extraction can be shared between a normal classifier and a Similarity Model, one might ask whether it is possible to perform nearest neighbors search on the classifier's feature extraction output instead of training the metric embeddings of a Similarity Model.
The answer is yes, but there are a couple of caveats that are worth pointing out.
Accuracy and speed
The embeddings of a Similarity Model are trained against a valid distance function, which makes it possible to perform approximate nearest neighbors search on them. The extracted features of a normal classifier come with no such guarantee.
As a result, doing nearest neighbors search on the feature extraction output of a classifier might require exact nearest neighbors search (which is expensive), and the results might not be as accurate. The main reason is that the feature extractor of a normal classifier is not trained with the same objective as a Similarity Model, whose loss functions focus on keeping similar samples close together and dissimilar samples far from each other.
Adding new classes
Adding new classes to a Similarity Model can be done simply by feeding the new classes' data into the trained model, adding the resulting metric embeddings to the look-up, and updating the ANN index and metadata (in diagram 1). A normal classifier, however, might require retraining, since the model is fine-tuned towards the current classes.
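The sketch below illustrates that workflow in plain Python. It rests on loud assumptions: `embed` is a hypothetical stand-in for the trained, frozen model (it only L2-normalizes its input), `EmbeddingIndex` is a made-up class rather than the library's index, and exact search replaces the ANN index.

```python
import math


def embed(sample):
    """Hypothetical stand-in for a trained, frozen embedding model.

    It only L2-normalizes the input; a real model would be a DNN.
    """
    norm = math.sqrt(sum(v * v for v in sample)) or 1.0
    return [v / norm for v in sample]


class EmbeddingIndex:
    """Toy look-up of embeddings and labels (exact search, not ANN)."""

    def __init__(self):
        self.embeddings = []
        self.labels = []

    def add(self, samples, label):
        """Embed the samples with the frozen model and store them."""
        for sample in samples:
            self.embeddings.append(embed(sample))
            self.labels.append(label)

    def lookup(self, sample, k=1):
        """Return the labels of the k nearest stored embeddings."""
        query = embed(sample)
        order = sorted(
            range(len(self.embeddings)),
            key=lambda i: math.dist(query, self.embeddings[i]),
        )
        return [self.labels[i] for i in order[:k]]


index = EmbeddingIndex()
index.add([[1.0, 0.0], [0.9, 0.1]], "cat")
index.add([[0.0, 1.0], [0.1, 0.9]], "dog")

query = [-1.0, 0.1]
before = index.lookup(query)  # the "bird" class is not indexed yet

# Add a new class: only the index changes, `embed` is left untouched.
index.add([[-1.0, 0.0], [-0.9, 0.1]], "bird")
after = index.lookup(query)
```

Before the new class is indexed, the query can only be matched to one of the existing classes. After the two "bird" samples are added, the same query resolves to "bird" without any change to the embedding function.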
In the next post, I'll walk through an example implementation of a Similarity Model and compare its performance against a normal DNN classifier.
Thanks for reading!