Tensorflow Similarity Model (Part 2) — Finding similar items on Fashion MNIST

Ata-tech
4 min read · May 21, 2022

This is the second part of the Tensorflow Similarity Model series. The first part is available at the following link:

Finding nearest items on Fashion MNIST

The second part focuses on a practical experiment in which Tensorflow Similarity is used to perform similarity search on the Fashion MNIST dataset. The model is trained on the Fashion MNIST training set, and predictions are made on the test set. A small web app is also implemented to make it easy to evaluate and interact with the model.

High-level overview of the project implementation

The project consists of 4 main steps:

  • Dataset preparation
  • Model training
  • Model evaluation (offline)
  • Building an interactive web app for running the product recommendation system

Step 1: Dataset preparation

Dataset preparation performs the following tasks (a short sketch follows the list):

  • Download data to a local directory
  • Reshape the train and test data into the shape required by the model training step
  • Select a subset of test images (in tensor format), convert them, and save them to a local directory (required for the web app)
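Below is a minimal sketch of what this step might look like. The data_dir path, the PNG output folder, and the number of saved test images are illustrative assumptions; the actual repository may organise the files differently. The train_images.npy and train_labels.npy names come from the training code shown in Step 2.

import os

import numpy as np
from PIL import Image
from tensorflow.keras.datasets import fashion_mnist

data_dir = "data"  # illustrative local directory (assumption)
os.makedirs(os.path.join(data_dir, "test_images"), exist_ok=True)

# download Fashion MNIST and reshape to (N, 28, 28, 1) for the Conv2D model
(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()
x_train = x_train.reshape(-1, 28, 28, 1)
x_test = x_test.reshape(-1, 28, 28, 1)

np.save(os.path.join(data_dir, "train_images.npy"), x_train)
np.save(os.path.join(data_dir, "train_labels.npy"), y_train)

# save a small subset of test images as PNGs for the web app (count is illustrative)
for i in range(100):
    Image.fromarray(x_test[i].squeeze()).save(os.path.join(data_dir, "test_images", f"{i}.png"))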

Step 2: Model training

Tensorflow Similarity provides different ways of loading and sampling data for model training. In this example, MultiShotMemorySampler from tensorflow_similarity.samplers is used to demonstrate how to load a custom dataset.

import os

import numpy as np
from tensorflow_similarity.samplers import MultiShotMemorySampler

x_train = np.load(os.path.join(data_dir, "train_images.npy"))
y_train = np.load(os.path.join(data_dir, "train_labels.npy"))
num_classes = len(np.unique(y_train))

# data sampler that generates balanced batches from the fashion-mnist dataset
sampler = MultiShotMemorySampler(
    x_train,
    y_train,
    classes_per_batch=num_classes,  # make sure all classes are available in each batch
)

Next, a simple model architecture is built for the Similarity model. For more complex datasets, a more sophisticated architecture is recommended. The main idea is to use MetricEmbedding as the output layer of the model.

from tensorflow.keras import layers
from tensorflow_similarity.layers import MetricEmbedding
from tensorflow_similarity.models import SimilarityModel

# build model architecture
inputs = layers.Input(shape=(28, 28, 1))
x = layers.Rescaling(1 / 255)(inputs)
x = layers.Conv2D(64, 3, activation="relu")(x)
x = layers.Flatten()(x)
x = layers.Dense(64, activation="relu")(x)
outputs = MetricEmbedding(64)(x)
model = SimilarityModel(inputs, outputs)

Since this is a metric learning problem, a different loss function is required. Hence, MultiSimilarityLoss(distance="cosine") is used in this case.
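A minimal compile-and-fit sketch is shown below; the optimizer, learning rate, number of epochs, and steps per epoch are illustrative choices, not values taken from the project.

from tensorflow.keras.optimizers import Adam
from tensorflow_similarity.losses import MultiSimilarityLoss

# compile with a metric-learning loss that operates on the embedding space
model.compile(optimizer=Adam(1e-3), loss=MultiSimilarityLoss(distance="cosine"))

# train on the balanced batches produced by the sampler
model.fit(sampler, epochs=10, steps_per_epoch=1000)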

After model training, an indexing step is required so that the training examples become searchable.

model.index(x=x_train, y=y_train, data=x_train)

Step 3: Model evaluation (optional)

This step supports manual evaluation: a random subset of test examples is selected, and the nearest training examples are retrieved for each one.
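As a rough sketch (assuming the test split was also saved as .npy files in Step 1, which is an assumption, and that each returned lookup result exposes label and distance attributes), such a manual check could look like this:

# load the test split (file names are assumptions)
x_test = np.load(os.path.join(data_dir, "test_images.npy"))
y_test = np.load(os.path.join(data_dir, "test_labels.npy"))

# pick a few random test examples and retrieve their nearest indexed neighbours
query_ids = np.random.choice(len(x_test), size=5, replace=False)
neighbours = model.lookup(x_test[query_ids], k=5)

for true_label, hits in zip(y_test[query_ids], neighbours):
    print(true_label, [(hit.label, round(float(hit.distance), 3)) for hit in hits])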

Step 4: Building an interactive web app

An interactive web app is developed using Dash and Plotly. The app displays a list of images used in model training and allows users to find the nearest images to a selected image.

The way it works is that after a user selects an image, the tensor of the selected image is sent to the trained model, and the model returns the indices of the items with the smallest distance to the query tensor.
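The core of that interaction can be sketched as a small helper around single_lookup. The function name below is hypothetical, and it assumes each returned hit exposes label, distance, and the data stored at index time; the actual app would then map these results back to the image files it displays.

def find_nearest(model, query_image, k=5):
    # query_image: a single (28, 28, 1) tensor selected in the UI
    hits = model.single_lookup(query_image, k=k)
    # each hit carries the label, the distance, and the data stored at index time
    return [(hit.label, float(hit.distance), hit.data) for hit in hits]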

As an example, a pullover from the test set was selected and all the returned nearest items also look like pullovers.

Selected Pullover
Nearest Pullovers found

My key takeaway

This is just a fun project developed to demonstrate the capabilities of the Tensorflow Similarity library. However, the recommendation system can be generalized to support different products and use cases. Also, instead of training on an offline dataset, the system can be designed with a continuous training flow that allows new data to be added to the model.

Traditionally, a recommendation system can be built by performing a K-nearest neighbours (KNN) search, either on raw features or on feature embeddings. However, these approaches require careful feature engineering, and KNN remains a separate step when feature embeddings are used as input to the search. Additionally, KNN struggles on high-dimensional features such as images (the curse of dimensionality, https://en.wikipedia.org/wiki/Curse_of_dimensionality). Using Tensorflow Similarity clearly reduces the engineering effort, since deep learning model training and nearest-neighbour search are incorporated into a single step. Moreover, the nearest-neighbour objective is already reflected in the training loss, so the expectation is that this improves system performance. A simpler pipeline also saves significant engineering cost.

The code for the project is available at https://github.com/att288/tf-similarity-fashion-mnist. There is a Docker deployment section that facilitates running the project. After running the Docker command, the web application should be available at http://localhost:8050.
