TensorFlow Similarity Model (Part 2): Finding similar items on Fashion MNIST

Finding nearest items on Fashion MNIST
High-level overview of the project implementation

Step 1: Dataset preparation
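The training code in Step 2 reads `train_images.npy` and `train_labels.npy` from a `data_dir` folder. A minimal sketch of producing those files is shown here, assuming the raw arrays are plain NumPy image and label arrays. The helper `save_split` is hypothetical, not the author's code.

```python
import os

import numpy as np


def save_split(images: np.ndarray, labels: np.ndarray, out_dir: str, prefix: str = "train") -> None:
    """Hypothetical helper: write <prefix>_images.npy and <prefix>_labels.npy into out_dir."""
    if len(images) != len(labels):
        raise ValueError("images and labels must have the same length")
    os.makedirs(out_dir, exist_ok=True)
    # keep raw uint8 pixels; rescaling to [0, 1] happens inside the model (Rescaling layer)
    np.save(os.path.join(out_dir, f"{prefix}_images.npy"), images.astype(np.uint8))
    np.save(os.path.join(out_dir, f"{prefix}_labels.npy"), labels.astype(np.int64))
```

With TensorFlow installed, `(x, y), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()` returns 28x28 uint8 images with integer labels 0 to 9, which could then be saved with `save_split(x, y, data_dir)`.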

Step 2: Model training

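import os

import numpy as np
from tensorflow.keras import layers
from tensorflow_similarity.layers import MetricEmbedding
from tensorflow_similarity.losses import MultiSimilarityLoss
from tensorflow_similarity.models import SimilarityModel
from tensorflow_similarity.samplers import MultiShotMemorySampler

# data_dir: folder containing the .npy files from Step 1 (defined elsewhere)
x_train = np.load(os.path.join(data_dir, "train_images.npy"))
y_train = np.load(os.path.join(data_dir, "train_labels.npy"))
x_train = x_train.reshape(-1, 28, 28, 1)  # add the channel axis expected by Conv2D
num_classes = len(np.unique(y_train))

# data sampler that generates balanced batches from the Fashion MNIST dataset
sampler = MultiShotMemorySampler(
    x_train, y_train,
    classes_per_batch=num_classes,  # make sure all classes are available in each batch
)

# build model architecture
inputs = layers.Input(shape=(28, 28, 1))
x = layers.Rescaling(1 / 255)(inputs)
x = layers.Conv2D(64, 3, activation="relu")(x)
x = layers.Flatten()(x)
x = layers.Dense(64, activation="relu")(x)
outputs = MetricEmbedding(64)(x)  # Dense layer whose output is L2-normalized
model = SimilarityModel(inputs, outputs)

# compile and train (loss and epoch count are assumptions, not from the original)
model.compile(optimizer="adam", loss=MultiSimilarityLoss())
model.fit(sampler, epochs=5)

# index the training images so they can be returned as nearest neighbors
model.index(x=x_train, y=y_train, data=x_train)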
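`MultiShotMemorySampler` builds each batch from `classes_per_batch` classes with several examples per class, so every batch contains same-class pairs for the metric loss to pull together. Here is a simplified NumPy sketch of that idea. The function `balanced_batch` and its parameters are hypothetical, and this is not the library's implementation.

```python
import numpy as np


def balanced_batch(y: np.ndarray, classes_per_batch: int, examples_per_class: int,
                   rng: np.random.Generator) -> np.ndarray:
    """Return indices of a batch holding `examples_per_class` items from each of
    `classes_per_batch` randomly chosen classes (hypothetical helper)."""
    classes = rng.choice(np.unique(y), size=classes_per_batch, replace=False)
    batch = []
    for c in classes:
        class_idx = np.flatnonzero(y == c)
        # sample without replacement so the positive pairs are distinct items
        batch.extend(rng.choice(class_idx, size=examples_per_class, replace=False))
    return np.array(batch)
```

With `classes_per_batch=num_classes`, as in the training code, all ten Fashion MNIST categories appear in every batch.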

Step 3: Model evaluation (optional)
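One common way to evaluate a similarity model is recall@k: the fraction of queries whose k nearest indexed neighbors include at least one item with the same label. The NumPy sketch below assumes embeddings for the query and indexed sets are already computed. It shows a generic metric, not necessarily what the author measured, and `recall_at_k` is a hypothetical helper.

```python
import numpy as np


def recall_at_k(query_emb, query_labels, index_emb, index_labels, k=5):
    """Fraction of queries with at least one same-label item among their k nearest
    index items, using cosine distance on L2-normalized embeddings."""
    q = query_emb / np.linalg.norm(query_emb, axis=1, keepdims=True)
    ix = index_emb / np.linalg.norm(index_emb, axis=1, keepdims=True)
    dist = 1.0 - q @ ix.T  # cosine distance, shape (num_queries, num_index)
    nearest = np.argsort(dist, axis=1)[:, :k]
    hits = (index_labels[nearest] == query_labels[:, None]).any(axis=1)
    return float(hits.mean())
```

Recall@k only checks whether a correct match appears in the top k. A stricter view would also look at precision over all k returned items.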

Step 4: Building an interactive web app

Selected Pullover
Nearest Pullovers found
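Behind a screen like "Selected Pullover / Nearest Pullovers found", the app embeds the selected image and returns the closest indexed items. In TensorFlow Similarity this is handled by the index built with `model.index(...)` and lookup methods on `SimilarityModel` such as `single_lookup`. The brute-force NumPy version below only illustrates the idea, assuming embeddings are already computed. `top_k_neighbors` is hypothetical.

```python
import numpy as np


def top_k_neighbors(embeddings: np.ndarray, query_idx: int, k: int = 5):
    """Return (indices, distances) of the k items closest to embeddings[query_idx],
    excluding the query itself (hypothetical helper for a 'find similar' page)."""
    # L2-normalize rows so a dot product equals cosine similarity
    emb = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    dist = 1.0 - emb @ emb[query_idx]  # cosine distance to every indexed item
    dist[query_idx] = np.inf  # the selected item would otherwise be its own nearest match
    order = np.argsort(dist)[:k]
    return order, dist[order]
```

Excluding the query matters here because the selected pullover comes from the same indexed set, so its distance to itself is zero.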

My key takeaway


