TensorFlow Similarity Model (Part 2): Finding similar items on Fashion MNIST

Finding nearest items on Fashion MNIST
High-level overview of the project implementation
  • Dataset preparation
  • Model training
  • Model evaluation (offline)
  • Building an interactive web app for running the product recommendation system

Step 1: Dataset preparation

Dataset preparation performs the following tasks:

  • Download data to a local directory
  • Reshape the train and test data into the shape required by the model training step
  • Select a subset of test images (in tensor format), convert them to image files, and save them to a local directory (required by the web app)
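The three preparation tasks can be sketched roughly as follows. This is a minimal sketch under assumptions, not the author's script: it assumes the data is downloaded with `tf.keras.datasets.fashion_mnist.load_data()`, that the test arrays are saved as `test_images.npy` / `test_labels.npy` next to the `train_images.npy` / `train_labels.npy` files used in Step 2, and that the selected test images are written as `<index>_<label>.png`. The helper names (`add_channel_axis`, `download_and_save`, `write_gray_png`, `save_test_subset`) are hypothetical, and a small stdlib PNG writer stands in for whatever image library the original project used.

```python
import os
import struct
import zlib

import numpy as np


def add_channel_axis(x):
    """(N, 28, 28) -> (N, 28, 28, 1), matching layers.Input(shape=(28, 28, 1))."""
    return np.asarray(x, dtype=np.uint8).reshape(-1, 28, 28, 1)


def download_and_save(data_dir):
    """Download Fashion MNIST via Keras and store the reshaped arrays as .npy files."""
    import tensorflow as tf  # only this function needs TensorFlow

    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()
    os.makedirs(data_dir, exist_ok=True)
    arrays = {
        "train_images": add_channel_axis(x_train),
        "train_labels": y_train,
        "test_images": add_channel_axis(x_test),  # file name assumed
        "test_labels": y_test,  # file name assumed
    }
    for name, arr in arrays.items():
        np.save(os.path.join(data_dir, f"{name}.npy"), arr)
    return arrays


def write_gray_png(path, img):
    """Write a 2-D uint8 array as an 8-bit grayscale PNG using only the stdlib."""
    img = np.asarray(img, dtype=np.uint8)
    height, width = img.shape

    def chunk(tag, data):
        # length + tag + payload + CRC32 over tag and payload
        return (struct.pack(">I", len(data)) + tag + data
                + struct.pack(">I", zlib.crc32(tag + data) & 0xFFFFFFFF))

    # every scanline starts with filter type 0 (no filtering)
    raw = b"".join(b"\x00" + row.tobytes() for row in img)
    header = struct.pack(">IIBBBBB", width, height, 8, 0, 0, 0, 0)  # 8-bit grayscale
    with open(path, "wb") as f:
        f.write(b"\x89PNG\r\n\x1a\n")
        f.write(chunk(b"IHDR", header))
        f.write(chunk(b"IDAT", zlib.compress(raw)))
        f.write(chunk(b"IEND", b""))


def save_test_subset(x_test, y_test, out_dir, n=20, seed=0):
    """Pick n random test images and save them as <index>_<label>.png files."""
    os.makedirs(out_dir, exist_ok=True)
    rng = np.random.default_rng(seed)
    paths = []
    for i in rng.choice(len(x_test), size=n, replace=False):
        path = os.path.join(out_dir, f"{i}_{int(y_test[i])}.png")
        write_gray_png(path, np.squeeze(x_test[i]))  # (28, 28, 1) -> (28, 28)
        paths.append(path)
    return paths
```

Keeping the download inside its own function means the reshaping and image-export helpers can be reused (or tested) without TensorFlow installed.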

Step 2: Model training

TensorFlow Similarity provides several ways of loading and sampling data for model training. In this example, MultiShotMemorySampler from tensorflow_similarity.samplers is used to demonstrate that any custom dataset held in memory (here, NumPy arrays loaded from disk) can be used for training.

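```python
import os

import numpy as np
from tensorflow.keras import layers
from tensorflow_similarity.layers import MetricEmbedding
from tensorflow_similarity.models import SimilarityModel
from tensorflow_similarity.samplers import MultiShotMemorySampler

data_dir = "data"  # local directory from Step 1 (directory name assumed)

x_train = np.load(os.path.join(data_dir, "train_images.npy"))
y_train = np.load(os.path.join(data_dir, "train_labels.npy"))
num_classes = len(np.unique(y_train))

# data sampler that generates balanced batches from the fashion-mnist dataset
sampler = MultiShotMemorySampler(
    x_train,
    y_train,
    classes_per_batch=num_classes,  # make sure all classes are available in each batch
)

# build model architecture
inputs = layers.Input(shape=(28, 28, 1))
x = layers.Rescaling(1 / 255)(inputs)
x = layers.Conv2D(64, 3, activation="relu")(x)
x = layers.Flatten()(x)
x = layers.Dense(64, activation="relu")(x)
outputs = MetricEmbedding(64)(x)
model = SimilarityModel(inputs, outputs)

# (compiling and fitting the model on `sampler` is not shown here)

# add the training images to the search index so they can be looked up later
model.index(x=x_train, y=y_train, data=x_train)
```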
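To make the idea of a "balanced batch" concrete, here is a simplified NumPy illustration of what such a sampler does. It is not MultiShotMemorySampler's actual implementation: the function `balanced_batch` and its `examples_per_class` parameter are hypothetical. For each batch it picks `classes_per_batch` distinct classes and then draws the same number of examples from each one.

```python
import numpy as np


def balanced_batch(y, classes_per_batch, examples_per_class, rng):
    """Return indices for one batch with `classes_per_batch` distinct classes
    and `examples_per_class` examples of each (simplified illustration only)."""
    y = np.asarray(y)
    chosen = rng.choice(np.unique(y), size=classes_per_batch, replace=False)
    batch = []
    for c in chosen:
        members = np.flatnonzero(y == c)
        # fall back to sampling with replacement if a class is too small
        replace = len(members) < examples_per_class
        batch.extend(rng.choice(members, size=examples_per_class, replace=replace))
    return np.array(batch)
```

With `classes_per_batch=num_classes`, every batch contains all ten Fashion MNIST classes, so each example has both same-class (positive) and different-class (negative) partners in the batch, which is what a metric-learning loss needs.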

Step 3: Model evaluation (optional)

This step supports manual evaluation: it randomly selects a subset of the test data and finds the nearest training examples for each selected image.
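The lookup in this step can be illustrated with a brute-force nearest-neighbour search over embeddings. TensorFlow Similarity handles this with its own index (populated by `model.index(...)` and queried with `model.lookup(...)`); the sketch below is a plain NumPy stand-in, assuming embeddings have already been computed (for example with `model.predict`) and that cosine distance is the metric. The names `nearest_neighbours` and `sample_queries` are hypothetical.

```python
import numpy as np


def sample_queries(n_test, n, seed=0):
    """Randomly choose n distinct test indices to evaluate."""
    rng = np.random.default_rng(seed)
    return rng.choice(n_test, size=n, replace=False)


def nearest_neighbours(query_emb, index_emb, k=5):
    """Brute-force k-NN by cosine distance.

    query_emb: (M, d) embeddings of the selected test images
    index_emb: (N, d) embeddings of the training images
    Returns (indices, distances), each of shape (M, k), closest first.
    """
    q = query_emb / np.linalg.norm(query_emb, axis=1, keepdims=True)
    r = index_emb / np.linalg.norm(index_emb, axis=1, keepdims=True)
    dist = 1.0 - q @ r.T  # cosine distance, shape (M, N)
    idx = np.argsort(dist, axis=1, kind="stable")[:, :k]
    return idx, np.take_along_axis(dist, idx, axis=1)
```

The returned indices point into the training set, so `x_train[idx]` gives the neighbouring images to display next to each query for a visual check.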

Step 4: Building an interactive web app

An interactive web app is built with Dash and Plotly. The app displays a list of images used in model training and lets users select an image to find the images nearest to it.
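A minimal Dash sketch of this interaction follows; it is not the author's app. It assumes the images are stored as PNG files in Dash's `assets/` folder (served automatically under `/assets/`) with hypothetical names `test_<i>.png` and `train_<j>.png`, and that nearest neighbours have been precomputed into a dict from test-image index to training-image indices. It uses plain `html.Img` components rather than Plotly figures; `neighbour_files` and `build_app` are hypothetical names.

```python
from typing import Dict, List, Sequence


def neighbour_files(selected: int, neighbours: Dict[int, List[int]]) -> List[str]:
    """Asset file names of the training images nearest to test image `selected`."""
    return [f"train_{j}.png" for j in neighbours.get(selected, [])]


def build_app(test_ids: Sequence[int], neighbours: Dict[int, List[int]]):
    # Dash is imported here so neighbour_files() works without Dash installed.
    from dash import Dash, Input, Output, dcc, html

    app = Dash(__name__)  # files in ./assets are served under /assets/
    app.layout = html.Div([
        dcc.Dropdown(
            id="selected",
            options=[{"label": f"Test image {i}", "value": i} for i in test_ids],
            value=test_ids[0],
        ),
        html.Div(id="results"),
    ])

    @app.callback(Output("results", "children"), Input("selected", "value"))
    def show_neighbours(selected):
        if selected is None:
            return []
        # show the selected image first, followed by its nearest neighbours
        names = [f"test_{selected}.png"] + neighbour_files(selected, neighbours)
        return [
            html.Img(src=app.get_asset_url(n), style={"width": "84px", "margin": "4px"})
            for n in names
        ]

    return app
```

Calling `build_app([0, 1], {0: [5, 9], 1: [2]}).run(debug=True)` would start a local server (older Dash releases use `app.run_server` instead). Keeping the lookup in a plain function separates the recommendation logic from the UI code.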

Selected Pullover
Nearest Pullovers found

My key takeaway

This is a fun project developed to demonstrate the capabilities of the TensorFlow Similarity library. The recommendation system could, however, be generalized to support other products and use cases. And instead of training on a fixed offline dataset, the system could be designed around a continuous training flow that adds new data to the model as it arrives.


