Build Android app for custom object image classification

(Using Google Colab)

Objective: Build an Android app for custom object image classification using Google Colab.

Train a Machine Learning model for custom object image classification using TensorFlow 2.x, convert it to a TFLite model, and finally deploy it on mobile devices using the sample TFLite image classification app from TensorFlow’s GitHub.


  • Collect the dataset of images.
  • Set up the environment, mount Google Drive, and create a backup folder on Drive.
  • Pre-process the data, then compile and train the model.
  • Evaluate the model.
  • Export the model.
  • Create a TFLite model for Android.
  • Download the sample image classification app from TensorFlow and adapt it for your custom model.

I am training a model for mask image classification. This is done in the 15 steps listed below. The first 11 steps are the same as in the "Train an ML model for custom object Image Classification using Google Colab" tutorial. After that, we convert the model to TFLite and deploy it on an Android device to build our own Android app for custom object image classification.

  1. Import dependencies and libraries
  2. Mount drive and link your folder
  3. Create a custom_ic folder in your Google Drive
  4. Upload your dataset to your drive and unzip it in the content directory
  5. Select variables
  6. Preprocess data
  7. Compile model
  8. Load TensorBoard
  9. Train the model
  10. Check the predictions
  11. Export your model
  12. Test the saved model by reloading it
  13. Convert the saved model to a TFLite model
  14. Create metadata for the TFLite model (optional)
  15. Set up your custom image classifier model in Android Studio


  • Open my Colab notebook in your browser.
  • Click File in the menu bar and then Save a copy in Drive. This opens a copy of my Colab notebook in your browser, which you can now use.
  • Once you have opened the copy of my notebook and are connected to the Google Colab VM, click Runtime in the menu bar and then Change runtime type. Select GPU and click Save.


1) Setup (Import dependencies and libraries)

from __future__ import absolute_import, division, print_function, unicode_literals

import matplotlib.pylab as plt
import tensorflow as tf
import tensorflow_hub as hub
from tensorflow.keras import layers
import numpy as np


2) Mount drive and link your folder

# mount drive

from google.colab import drive
drive.mount('/content/gdrive')

# this creates a symbolic link so that now the path /content/gdrive/My Drive/ is equal to /mydrive
# (the space in "My Drive" must be escaped in the shell command)

!ln -s /content/gdrive/My\ Drive/ /mydrive
!ls /mydrive

3) Create a custom_ic folder in your Google Drive

Create a folder named custom_ic in your Google Drive for your custom image classification files.

4) Upload your dataset to your drive and unzip it in the content directory

# Upload the dataset (a zip file) to the custom_ic folder in your drive and copy it to the content dir
!cp /mydrive/custom_ic/ /content

# Unzip the dataset (put the name of your zip file between the quotes)
!unzip -qq ""
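If you prefer to unzip from Python instead of the shell, or want to double-check the folder layout that flow_from_directory expects in step 6 (one subfolder per class inside obj/), here is a minimal sketch using only the standard library. The helper name extract_dataset is hypothetical, and the assumption that the archive contains a top-level obj/ folder is inferred from the 'obj' directory used later, not confirmed by the dataset itself.

```python
import os
import zipfile


def extract_dataset(zip_path, dest_dir="/content", data_dir="obj"):
    """Extract a dataset zip and return the class subfolder names under data_dir.

    Assumes (hypothetically) that the archive contains a top-level folder
    named data_dir with one subfolder per class, e.g. obj/with_mask/...
    """
    with zipfile.ZipFile(zip_path) as zf:
        zf.extractall(dest_dir)  # creates dest_dir and subfolders as needed
    root = os.path.join(dest_dir, data_dir)
    # flow_from_directory treats each subdirectory of root as one class
    return sorted(
        name for name in os.listdir(root)
        if os.path.isdir(os.path.join(root, name))
    )
```

Calling it on your uploaded zip should return one name per class. If it raises FileNotFoundError, the archive layout differs from the assumed obj/<class>/ structure and the path passed to flow_from_directory needs to change accordingly.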

Simple transfer learning

5) Select Variables

Select Model

#@title Enter Variables

Type_of_model = 'efficient' #@param ["efficient", "mobilenetv2", "inceptionv3", "inceptionResnetv2"]

How to Choose a Model


# MobileNet: for a small number of classes, when you need a small model that trains quickly.
# Inception: more accurate than MobileNet. InceptionResNetV2 is more accurate than InceptionV3, but InceptionV3 has fewer parameters, so it is smaller.
# If you need a small model with more accuracy, choose InceptionV3; if size is not an issue, choose InceptionResNetV2.

# EfficientNet: larger, more accurate, and slower to train.

# Advice: EfficientNet is the state of the art, so use it.

# There are other state-of-the-art models, but their packages are not available at the moment.

# Accuracy:
# EfficientNet > InceptionResNetV2 > InceptionV3 >= MobileNet

# Size:
# MobileNet < InceptionV3 < InceptionResNetV2 < EfficientNet

model_link = {"efficient":"", 
              "mobilenetv2" : "", 
              "inceptionv3" : "", 
              "inceptionResnetv2" : ""}

model_shape = {"efficient":(300,300), 
              "mobilenetv2" : (224,224), 
              "inceptionv3" : (300,300), 
              "inceptionResnetv2" : (300,300)}

6) Preprocess data

IMAGE_SHAPE = model_shape[Type_of_model]
# Rescale pixel values to [0, 1] and hold out 30% of the images for validation
image_generator = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1/255, validation_split=0.3)

train_gen = image_generator.flow_from_directory('obj', target_size=IMAGE_SHAPE, subset='training')
val_gen = image_generator.flow_from_directory('obj', target_size=IMAGE_SHAPE, subset='validation')

# The generators loop forever, so stop after inspecting the first batch
for image_batch, label_batch in train_gen:
  print("Image batch shape: ", image_batch.shape)
  print("Label batch shape: ", label_batch.shape)
  break
for image_batch, label_batch in val_gen:
  print("Image batch shape: ", image_batch.shape)
  print("Label batch shape: ", label_batch.shape)
  break

feature_extractor_url = model_link[Type_of_model]

feature_extractor_layer = hub.KerasLayer(feature_extractor_url,
                                         input_shape=IMAGE_SHAPE+(3,))

feature_batch = feature_extractor_layer(image_batch)
print(feature_batch.shape)

# Freeze the pre-trained weights so that only the new classifier layer is trained
feature_extractor_layer.trainable = False
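To get a feel for how validation_split=0.3 divides each class folder, and how many steps one epoch takes in step 9, here is a small arithmetic sketch. It assumes, based on my reading of the Keras DirectoryIterator source, that a class folder with n images contributes int(0.3 * n) images to validation and the rest to training, and it uses flow_from_directory's default batch_size of 32. The helper split_counts and the folder sizes are hypothetical.

```python
import math


def split_counts(class_sizes, validation_split=0.3, batch_size=32):
    """Estimate per-class train/validation counts and steps per epoch.

    Assumes each class folder with n images contributes
    int(validation_split * n) images to validation and the rest to training.
    """
    train, val = {}, {}
    for name, n in class_sizes.items():
        n_val = int(validation_split * n)
        val[name] = n_val
        train[name] = n - n_val
    total_train = sum(train.values())
    # Same formula as steps_per_epoch in step 9: ceil(samples / batch_size)
    steps_per_epoch = math.ceil(total_train / batch_size)
    return train, val, steps_per_epoch


# Hypothetical folder sizes, not the real mask dataset
train, val, steps = split_counts({"with_mask": 7, "without_mask": 13})
print(train, val, steps)
```

The printed counts should match the "Found N images belonging to 2 classes" lines that flow_from_directory prints for the training and validation subsets, which is a quick way to confirm the split behaves as assumed.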

7) Compile model

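model = tf.keras.Sequential([
  feature_extractor_layer,
  # 2 output classes for this mask dataset; use train_gen.num_classes for other datasets
  layers.Dense(2, activation='softmax')
])

model.summary()

predictions = model(image_batch)

#Use compile to configure the training process:
model.compile(
  optimizer=tf.keras.optimizers.Adam(),
  loss='categorical_crossentropy',
  metrics=['acc'])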

Set model directory and checkpoint

import os

model_dir = "/mydrive/custom_ic/training/"
checkpoint_dir = "weights"

os.makedirs(model_dir, exist_ok=True)
# os.makedirs(checkpoint_dir, exist_ok=True)

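class CollectBatchStats(tf.keras.callbacks.Callback):
  """Records the loss and accuracy of every training batch for plotting."""
  def __init__(self):
    super().__init__()
    self.batch_losses = []
    self.batch_acc = []

  def on_train_batch_end(self, batch, logs=None):
    self.batch_losses.append(logs['loss'])
    self.batch_acc.append(logs['acc'])
    self.model.reset_metrics()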

from tensorflow.keras.callbacks import TensorBoard, ModelCheckpoint

checkpoint = ModelCheckpoint(checkpoint_dir, monitor='val_loss', verbose=0)
tensorboard = TensorBoard(log_dir=model_dir)

8) Load TensorBoard

%load_ext tensorboard
%tensorboard --logdir "/mydrive/custom_ic/training/"

9) Train the model

steps_per_epoch = np.ceil(train_gen.samples/train_gen.batch_size)

batch_stats_callback = CollectBatchStats()

# If you want to save every checkpoint, change callbacks to [batch_stats_callback, tensorboard, checkpoint]

# model.fit accepts generators directly; fit_generator is deprecated in TensorFlow 2.x
history = model.fit(train_gen, validation_data=val_gen, epochs=11,
                    steps_per_epoch=int(steps_per_epoch),
                    callbacks=[batch_stats_callback, tensorboard])
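#To keep Colab from disconnecting during long training runs, press Ctrl + Shift + I, go to the Console tab, paste the following code, and press Enter.
function ClickConnect(){
  document.querySelector('#top-toolbar > colab-connect-button').shadowRoot.querySelector('#connect').click();
}
setInterval(ClickConnect, 60000);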

Check training progress

Even after just a few training iterations, we can already see that the model is making progress on the task.

Plot loss graph

plt.xlabel("Training Steps")
plt.plot(batch_stats_callback.batch_losses)

Plot accuracy graph

plt.xlabel("Training Steps")
plt.plot(batch_stats_callback.batch_acc)

10) Check the predictions

To plot the predictions, first get the list of class names ordered by class index:

import numpy as np
class_names = sorted(val_gen.class_indices.items(), key=lambda pair:pair[1])
class_names = np.array([key.title() for key, value in class_names])

Run the image batch through the model and convert the indices to class names.

predicted_batch = model.predict(image_batch)
predicted_id = np.argmax(predicted_batch, axis=-1)
predicted_label_batch = class_names[predicted_id]

Plot the result

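label_id = np.argmax(label_batch, axis=-1)

plt.figure(figsize=(10,9))
plt.subplots_adjust(hspace=0.5)
for n in range(min(30, len(image_batch))):
  plt.subplot(6,5,n+1)
  plt.imshow(image_batch[n])
  color = "green" if predicted_id[n] == label_id[n] else "red"
  plt.title(predicted_label_batch[n].title(), color=color)
  plt.axis('off')
_ = plt.suptitle("Model predictions (green: correct, red: incorrect)")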
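The green/red coloring in the plot is an element-wise comparison of predicted and true class indices. The sketch below reproduces that logic with NumPy on made-up softmax outputs, which can help when checking the class-name ordering or when you want the batch accuracy as a single number. The class_indices dict and all probability values are hypothetical.

```python
import numpy as np

# Hypothetical class_indices, in the form produced by flow_from_directory
class_indices = {"with_mask": 0, "without_mask": 1}
# Order names by their class index, as in step 10
class_names = np.array(
    [k.title() for k, _ in sorted(class_indices.items(), key=lambda p: p[1])]
)

# Made-up softmax outputs for 4 images and their one-hot labels
predicted_batch = np.array([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4], [0.3, 0.7]])
label_batch = np.array([[1, 0], [0, 1], [0, 1], [0, 1]])

predicted_id = np.argmax(predicted_batch, axis=-1)
label_id = np.argmax(label_batch, axis=-1)
predicted_label_batch = class_names[predicted_id]

correct = predicted_id == label_id  # True -> green title, False -> red title
batch_accuracy = correct.mean()
print(predicted_label_batch, batch_accuracy)
```

With real data, replace the made-up arrays with the model.predict output and the label_batch from val_gen; the indexing logic stays the same.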

11) Export your model

Now that you’ve trained the model, export it as a SavedModel:

# Save in TensorFlow SavedModel format
export_path = "saved_models"
model.save(export_path, save_format='tf')

# Optionally save an HDF5 copy to your drive; reloading it needs custom_objects for the TF Hub layer
model.save('/mydrive/custom_ic/image_classification_model.h5')
my_reloaded_model = tf.keras.models.load_model('/mydrive/custom_ic/image_classification_model.h5', custom_objects={'KerasLayer':hub.KerasLayer})

# List the devices TensorFlow can see (to confirm that the GPU runtime is active)
from tensorflow.python.client import device_lib
print(device_lib.list_local_devices())

12) Test the saved model by reloading

Reload the model

export_path = "saved_models"
reloaded = tf.keras.models.load_model(export_path)

Run inference to test the model

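import cv2
from google.colab.patches import cv2_imshow

# Labels in class-index order (class_names from step 10), or enter your labels here in the same order
labels = class_names

# Read an image. OpenCV loads images in BGR order, but the model was trained on RGB images.
image = cv2.imread('/mydrive/mask_test_images/image1.jpg')
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

# Resize, scale to [0, 1] as during training, and add a batch dimension
crop_img = cv2.resize(image, IMAGE_SHAPE)
crop_img = crop_img/255
crop_img = crop_img.reshape(1,IMAGE_SHAPE[0],IMAGE_SHAPE[1],3)

result = reloaded.predict(crop_img)
print('result', result)
predicted_id = np.argmax(result, axis=-1)
print('Predicted label:', labels[predicted_id[0]])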
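The preprocessing at inference time has to mirror what ImageDataGenerator did during training: RGB channel order, values scaled by 1/255, and a leading batch dimension. As a library-free illustration, the sketch below does the same with NumPy alone on an image array that is assumed to be already resized. preprocess_for_model is a hypothetical helper name; reversing the last axis is the NumPy equivalent of cv2.COLOR_BGR2RGB for a 3-channel image.

```python
import numpy as np


def preprocess_for_model(bgr_image):
    """Turn an already-resized HxWx3 uint8 BGR image into a 1xHxWx3 float batch in [0, 1]."""
    rgb = bgr_image[..., ::-1]  # BGR -> RGB by reversing the channel axis
    scaled = rgb.astype(np.float32) / 255.0  # same scaling as rescale=1/255
    return scaled[np.newaxis, ...]  # add the batch dimension


# A fake 2x2 "image" in which every pixel has B=0, G=128, R=255
fake = np.zeros((2, 2, 3), dtype=np.uint8)
fake[..., 1] = 128
fake[..., 2] = 255
batch = preprocess_for_model(fake)
print(batch.shape, batch[0, 0, 0])
```

Skipping the channel swap does not raise any error; the model simply sees red and blue exchanged, which can quietly lower accuracy on real photos.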


13) Convert the saved model to TFLite model

# Convert the model to TFLite
converter = tf.lite.TFLiteConverter.from_saved_model(export_path)
tflite_model = converter.convert()

# Save the TFLite model.
with open("/mydrive/custom_ic/model.tflite", "wb") as f:
  f.write(tflite_model)

# Write the labels file: one label per line, sorted alphabetically (the same order as the class indices)
labels = '\n'.join(sorted(train_gen.class_indices.keys()))
with open('/mydrive/custom_ic/labels.txt', 'w') as f:
  f.write(labels)
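The order of lines in labels.txt must match the model's output indices, since, as far as I can tell, the Android app maps output position i to line i of the labels file. flow_from_directory assigns indices to class folders in alphabetical order, which is why sorting the keys works; the sketch below makes that ordering explicit by sorting on the index value instead, then reads the file back to check it. write_labels and the class_indices values are hypothetical, and a temporary directory stands in for /mydrive/custom_ic.

```python
import os
import tempfile


def write_labels(class_indices, path):
    """Write one label per line, ordered by class index (0, 1, 2, ...)."""
    ordered = sorted(class_indices, key=class_indices.get)
    with open(path, "w") as f:
        f.write("\n".join(ordered))
    return ordered


class_indices = {"without_mask": 1, "with_mask": 0}  # hypothetical
path = os.path.join(tempfile.mkdtemp(), "labels.txt")
write_labels(class_indices, path)
with open(path) as f:
    lines = f.read().splitlines()
print(lines)
```

Sorting by index rather than by name keeps the file correct even if you ever pass an explicit classes=[...] list to flow_from_directory in a non-alphabetical order.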

14) Create metadata for the TFLite model (optional)

(If you use the lib_task_api implementation, metadata must be attached to the TFLite model. The lib_support implementation does not need it. If you’re a beginner, use lib_support to get started. You can read more about these two implementations in step 15.)

First, install tflite-support:

!pip install -q tflite-support

Create the folders model_without_metadata and model_with_metadata to keep the two versions apart. Next, move the TFLite model created in step 13 into the model_without_metadata folder.

Change the current working directory to /mydrive/custom_ic:

%cd /mydrive/custom_ic

!mkdir model_without_metadata
!mkdir model_with_metadata

!mv *.tflite model_without_metadata/

Define metadata

import os
import tensorflow as tf

from tflite_support import flatbuffers
from tflite_support import metadata as _metadata
from tflite_support import metadata_schema_py_generated as _metadata_fb

# This is where we will export a new .tflite model file with metadata, and a .json file with metadata info
EXPORT_DIR = "model_with_metadata"

class MetadataPopulatorForImageClassifier(object):
  """Populates the metadata for the mask detector model."""

  def __init__(self, model_file):
    self.model_file = model_file
    self.metadata_buf = None

  def populate(self):
    """Creates metadata and then populates it for the image classifier model."""
    self._create_metadata()
    self._populate_metadata()

  def _create_metadata(self):
    """Creates the metadata for the mask detector model."""

    # Creates model info.
    model_meta = _metadata_fb.ModelMetadataT()
    model_meta.name = "Image Classifier"
    model_meta.description = ("MASK DETECTION.")
    model_meta.version = "v1"
    model_meta.author = "TensorFlow"
    model_meta.license = ("Apache License. Version 2.0 "
                          "http://www.apache.org/licenses/LICENSE-2.0.")

    # Creates info for the input image.
    input_image_meta = _metadata_fb.TensorMetadataT()
    input_image_meta.name = "source_image"
    input_image_meta.description = (
        "The expected image is 300 x 300, with three channels "
        "(red, green, and blue) per pixel. Each value in the tensor is between"
        " 0 and 1.")
    input_image_meta.content = _metadata_fb.ContentT()
    input_image_meta.content.contentProperties = (
        _metadata_fb.ImagePropertiesT())
    input_image_meta.content.contentProperties.colorSpace = (
        _metadata_fb.ColorSpaceType.RGB)
    input_image_meta.content.contentPropertiesType = (
        _metadata_fb.ContentProperties.ImageProperties)
    # (pixel - mean) / std with mean 0 and std 255 matches rescale=1/255 used in training
    input_image_normalization = _metadata_fb.ProcessUnitT()
    input_image_normalization.optionsType = (
        _metadata_fb.ProcessUnitOptions.NormalizationOptions)
    input_image_normalization.options = _metadata_fb.NormalizationOptionsT()
    input_image_normalization.options.mean = [0.0]
    input_image_normalization.options.std = [255.0]
    input_image_meta.processUnits = [input_image_normalization]
    input_image_stats = _metadata_fb.StatsT()
    input_image_stats.max = [1.0]
    input_image_stats.min = [0.0]
    input_image_meta.stats = input_image_stats

    # Creates output info: one probability per class.
    output_meta = _metadata_fb.TensorMetadataT()
    output_meta.name = "probability"
    output_meta.description = "Probabilities of the labels respectively."
    output_meta.content = _metadata_fb.ContentT()
    output_meta.content.contentProperties = _metadata_fb.FeaturePropertiesT()
    output_meta.content.contentPropertiesType = (
        _metadata_fb.ContentProperties.FeatureProperties)
    output_stats = _metadata_fb.StatsT()
    output_stats.max = [1.0]
    output_stats.min = [0.0]
    output_meta.stats = output_stats

    # Creates subgraph info.
    subgraph = _metadata_fb.SubGraphMetadataT()
    subgraph.inputTensorMetadata = [input_image_meta]
    subgraph.outputTensorMetadata = [output_meta]
    model_meta.subgraphMetadata = [subgraph]

    # Serializes the metadata into a flatbuffer.
    b = flatbuffers.Builder(0)
    b.Finish(
        model_meta.Pack(b),
        _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
    self.metadata_buf = b.Output()

  def _populate_metadata(self):
    """Populates metadata to the model file."""
    populator = _metadata.MetadataPopulator.with_model_file(self.model_file)
    populator.load_metadata_buffer(self.metadata_buf)
    populator.populate()


def populate_metadata(model_file):
  """Populates the metadata into a copy of the model inside EXPORT_DIR.

  Args:
      model_file: valid path to the model file.
  """

  # Copies the model into EXPORT_DIR and populates metadata into the copy.
  model_file_basename = os.path.basename(model_file)
  export_path = os.path.join(EXPORT_DIR, model_file_basename)
  tf.io.gfile.copy(model_file, export_path, overwrite=True)

  populator = MetadataPopulatorForImageClassifier(export_path)
  populator.populate()

  # Displays the metadata that was just populated into the tflite model.
  displayer = _metadata.MetadataDisplayer.with_model_file(export_path)
  export_json_file = os.path.join(
      EXPORT_DIR, os.path.splitext(model_file_basename)[0] + ".json")
  json_file = displayer.get_metadata_json()
  with open(export_json_file, "w") as f:
    f.write(json_file)
  print("Finished populating metadata and associated file to the model:")
  print(export_path)
  print("The metadata json file has been saved to:")
  print(export_json_file)
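The NormalizationOptions in TFLite metadata describe the computation (pixel - mean) / std applied to each input value before inference, so mean and std have to reproduce the training preprocessing (rescale=1/255 in step 6). A quick NumPy check of two common settings: mean=0, std=255 maps [0, 255] to [0, 1], while mean=127.5, std=127.5 maps it to [-1, 1], which would suit a model trained on [-1, 1] inputs instead. normalize is a hypothetical helper for illustration only.

```python
import numpy as np


def normalize(pixels, mean, std):
    """Apply the (x - mean) / std normalization described by NormalizationOptions."""
    return (np.asarray(pixels, dtype=np.float32) - mean) / std


pixels = np.array([0, 127.5, 255])
print(normalize(pixels, mean=0.0, std=255.0))    # [0, 1] range, same as rescale=1/255
print(normalize(pixels, mean=127.5, std=127.5))  # [-1, 1] range instead
print(np.allclose(normalize(pixels, 0.0, 255.0), pixels / 255))
```

If the metadata and training preprocessing disagree, nothing fails at load time; the app just feeds the model values in the wrong range, so this check is worth doing whenever you change the rescale step.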

Choose the type of model you are creating (float, int8, etc.) and finally populate the TFLite model with metadata.

quantization = "fp16" #@param ["dr", "int8", "fp16"]
# Only the float model.tflite converted in step 13 is used below
tflite_model_path = "model.tflite"
MODEL_FILE = "/mydrive/custom_ic/model_without_metadata/{}".format(tflite_model_path)

populate_metadata(MODEL_FILE)

The TFLite model with metadata is created inside the model_with_metadata folder along with the .json file. Download both these files and the labels.txt file to use in Android Studio for the Image Classification app.

15) Download the TFLite model and adjust the TFLite Image Classification sample app with your custom model

  • Download the TensorFlow Lite examples archive from the tensorflow/examples repository on GitHub and unzip it. You will find an image classification app inside lite/examples/image_classification/android.


  • Next, copy the model.tflite model with metadata and the labels.txt file into the assets folder of the image classification Android app.


IMPORTANT: Switch between inference solutions (Task library vs Support Library)

This Image Classification Android reference app demonstrates two implementation solutions:

(1) lib_task_api that leverages the out-of-box API from the TensorFlow Lite Task Library;

(2) lib_support that creates the custom inference pipeline using the TensorFlow Lite Support Library.

The build.gradle file inside the app folder shows how to change flavorDimensions "tfliteInference" to switch between the two solutions.

Inside Android Studio, you can change the build variant to whichever one you want to build and run: go to Build > Select Build Variant and select one from the drop-down menu. See configure product flavors in Android Studio for more details.

From the Gradle command line, running ./gradlew build creates APKs under app/build/outputs/apk for both solutions.

Note: If you simply want the out-of-box API to run the app, we recommend lib_task_api for inference. If you want to customize your own models and control the detail of inputs and outputs, it might be easier to adapt your model inputs and outputs by using lib_support.

The file download.gradle directs gradle to download the two models used in the example, placing them into assets.

For beginners, I recommend “lib_support”. The default implementation is “lib_task_api”. The “lib_support” implementation does not require any metadata, while the “lib_task_api” implementation does. (I showed how to attach metadata to the TFLite model in step 14.)

To choose a build variant, select the app module in the project manager inside Android Studio, go to Build in the menu bar, and click on Select Build Variant. You will see a window pop up where you can choose your variant as shown below.

  • Next, make the following changes to the code:

  • Edit the Gradle build module file. Open the “build.gradle” file $TF_EXAMPLES/lite/examples/image_classification/android/app/build.gradle and comment out the line apply from: 'download_model.gradle', which downloads the default image classification app’s TFLite model and would overwrite your assets.
// apply from:'download_model.gradle'

Select Build -> Make Project and check that the project builds successfully. You will need the Android SDK configured in the settings, with at least SDK version 23. The build.gradle file will prompt you to download any missing libraries.

  • Next, if your model is named model.tflite and your labels file labels.txt, the example will use them automatically as long as they’ve been copied into the base assets directory. To confirm, open the $TF_EXAMPLES/lite/examples/image_classification/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/ file in a text editor or in Android Studio and find the definition of getModelPath(). Verify that it points to your model file, “model.tflite”. Then verify that getLabelPath() points to your labels file, “labels.txt”.
  • If you also have the quantized TFLite model for MobileNet, and both the float and quantized models for EfficientNet, you can repeat the step above and point getModelPath() and getLabelPath() in each classifier to its respective model. See pic below. Note that every classifier the app references needs a TFLite model for your custom dataset; otherwise the app will give an error. In this tutorial, I am only demonstrating the Float MobileNet model, so I removed all usages of and references to the other 3 classifiers in Android Studio.
  • Since we are using only 2 classes, I also made another change: change MAX_RESULTS in and change the result size from 3 to 2 in the showResultsInBottomSheet() function in and comment out the third detected item (recognition2). See pics below. (NOTE: If you have 3 or more classes, you do not need to make this change. I have shared my custom Android app on GitHub; you can find the link below in the credits section.)
  • Finally, connect a mobile device and run the app. Test your app before making any other new improvements or adding more features to it. Now that you’ve made a basic Android app for custom object image classification, you can try and make changes to your TFLite model. I have also given the links to the TensorFlow site’s pages where you can learn how to apply other features and optimizations like post-training quantization etc. Read the links given below in the Documentation under the Credits section. Have fun!
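Before changing MAX_RESULTS, it can help to confirm how many classes your labels.txt actually lists, since the sample app shows up to 3 results and, as described in the steps on MAX_RESULTS and showResultsInBottomSheet(), needs adjusting when there are fewer classes. The Python sketch below counts the non-empty lines of a labels file; results_to_show and its app_max_results=3 default are hypothetical names that simply encode "show at most 3 results, or fewer if there are fewer classes".

```python
import os
import tempfile


def results_to_show(labels_path, app_max_results=3):
    """Return how many result rows make sense for the given labels file."""
    with open(labels_path) as f:
        labels = [line.strip() for line in f if line.strip()]
    return min(app_max_results, len(labels))


# Example with a two-class labels file like the one written in step 13
path = os.path.join(tempfile.mkdtemp(), "labels.txt")
with open(path, "w") as f:
    f.write("with_mask\nwithout_mask")
print(results_to_show(path))
```

For the two-class mask model this gives 2, which is the value the MAX_RESULTS and result-size changes are meant to reflect; with 3 or more classes it stays at 3 and no change is needed.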

Read this TensorFlow image classification app GitHub Readme page to learn more-> TensorFlow Lite image classification Android example application

My GitHub

Image Classification sample app

Custom Image Classification

Mask dataset

I have used the following dataset for test purposes. You can use other datasets; see more dataset sources in the credits section below.

Prajnasb Github

My Colab notebook for this

Build Image Classification custom app

Check out my Youtube Video on this

Coming Soon!


Documentation / References

Dataset Sources

You can download datasets for many objects from the sites mentioned below. These sites also contain images of many classes of objects along with their annotations/labels in multiple formats such as the YOLO_DARKNET txt files and the PASCAL_VOC xml files.

More Mask Datasets
