Machine Learning / MLKit

ML for Android Developers – Part 3

Machine Learning is very much in vogue at the moment and there are many Android folks producing some excellent guides to creating and training models. However, I do not feel that, in the vast majority of cases, Android developers will be required to train their own models. But that does not mean that using a model that someone else has created is a straightforward task. In this series, we’ll take a look at some of the problems and pain points that I encountered while attempting to incorporate a pre-trained TensorFlow Lite model into an Android app.

In the first article we examined our trained MNIST number recognition model to understand its inputs and outputs, which is an absolute requirement before we can implement that model within our app; and in the second article we looked at the benefits of hosting that model in the cloud using MLKit, and also created the UI. In this final article we’ll complete our digit recognition app by hooking up the MLKit model.

The glue that ties this all together is the NumberClassifier class:

class NumberClassifier(
        provider: InterpreterProvider = InterpreterProvider(),
        private val interpreter: FirebaseModelInterpreter =
                provider.getInterpreter() ?: throw Exception("Unable to get Interpreter")
) {
    private val classCount = 10
    private val imageWidth = 28
    private val imageHeight = 28
    private val imageSize = imageWidth * imageHeight

    private val imagePixels = IntArray(imageSize)
    
    private val options = FirebaseModelInputOutputOptions.Builder()
            .setInputFormat(0, FirebaseModelDataType.FLOAT32, intArrayOf(1, imageWidth, imageHeight, 1))
            .setOutputFormat(0, FirebaseModelDataType.FLOAT32, intArrayOf(1, classCount))
            .build()

    fun classify(
            bitmap: Bitmap,
            failure: (Exception) -> Nothing = { exception -> throw(exception) },
            success: (Int, Float, Long) -> Unit
    ) {
        val inputs = FirebaseModelInputs.Builder()
                .add(bitmap.toVector())
                .build()
        val start = System.currentTimeMillis()
        interpreter.run(inputs, options)
                .addOnSuccessListener { outputs ->
                    outputs.map().entries.maxBy { it.value }?.also {
                        success(it.key, it.value, System.currentTimeMillis() - start)
                    }
                }
                .addOnFailureListener {
                    failure(it)
                }
    }

    private fun Bitmap.toVector(): Array<Array<Array<FloatArray>>> {
        getPixels(imagePixels, 0, width, 0, 0, width, height)
        return Array(1) {
            Array(imageHeight) { y ->
                Array(imageWidth) { x ->
                    floatArrayOf(imagePixels[x + (y * imageWidth)].convertToGreyScale())
                }
            }
        }
    }

    private fun Int.convertToGreyScale(): Float =
            1f - ((Color.red(this) + Color.green(this) + Color.blue(this)).toFloat() / 3f / 255f)

    private fun FirebaseModelOutputs.map(): Map<Int, Float> {
        return getOutput<Array<FloatArray>>(0)[0].mapIndexed { index, fl -> index to fl }.toMap()
    }
}

In the constructor we obtain an interpreter for the model that we host in the cloud through MLKit. We also define some properties which will be used to convert the image data and pass it to the model.
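We covered obtaining the interpreter in Part 2, so the provider is not repeated here in full; but as a rough, minimal sketch it might look something like the following. The model name "mnist" and the update settings are illustrative assumptions rather than the exact code from Part 2:

class InterpreterProvider {

    fun getInterpreter(): FirebaseModelInterpreter? {
        // Register the cloud-hosted model so that MLKit knows where to download it from.
        // The name "mnist" is an assumption for this sketch.
        val cloudSource = FirebaseCloudModelSource.Builder("mnist")
                .enableModelUpdates(true)
                .build()
        FirebaseModelManager.getInstance().registerCloudModelSource(cloudSource)

        // Obtain an interpreter bound to that cloud model
        val modelOptions = FirebaseModelOptions.Builder()
                .setCloudModelName("mnist")
                .build()
        return FirebaseModelInterpreter.getInstance(modelOptions)
    }
}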

The options property defines the inputs and outputs of our model, which we determined in Part 1 of this series.

The classify() method is where the work happens. We first convert the Bitmap to four nested arrays to match the four-dimensional tensor that is required for the input; the Bitmap extension function named toVector() is where this conversion happens. The outermost Array is the batch; next is an Array of rows in the image; inside that is an Array of the individual pixels within that row; and inside that is an array of the components making up each pixel value. For an RGB image there might be three components, but in our case we convert the RGB value to a single float in the range 0.0-1.0 representing a greyscale value (where 0.0 is black, and 1.0 is white). These arrays correspond to the dimensions of the input that we declared for the options property, and it is important that they match, otherwise you’ll get runtime errors.
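To make the shape concrete, here is a small standalone illustration (not part of the app) of the nesting that toVector() produces, which must line up with the intArrayOf(1, imageWidth, imageHeight, 1) that we passed to setInputFormat():

// Illustrative only: the nesting corresponds to [batch, rows, pixels, channels],
// i.e. [1, 28, 28, 1] for a single 28x28 greyscale image.
val input: Array<Array<Array<FloatArray>>> =
        Array(1) {                   // batch of one image
            Array(28) {              // 28 rows
                Array(28) {          // 28 pixels per row
                    floatArrayOf(0f) // one greyscale component per pixel
                }
            }
        }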

We then pass this to the interpreter along with the options which define the input and output formats. Running the interpreter performs digit classification on the input image, and the set of outputs passed to the onSuccessListener is an array of 10 floats representing the probability that the digit is each of the possible values that we’re interested in. The value at index 0 is the probability that the digit is the number 0, the value at index 1 is the probability that the digit is the number 1, and so on.
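Concretely, the raw output can be read back as a [1, 10] array of floats, where the index is the digit; this fragment simply restates what the map() extension function above is doing:

// The output tensor is [1, 10]: one batch entry containing ten probabilities
val probabilities: FloatArray = outputs.getOutput<Array<FloatArray>>(0)[0]
// probabilities[0] is the probability of the digit being 0,
// probabilities[7] is the probability of it being 7, and so on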

We convert this to a map of digit to probability and pick the entry with the highest probability, which is the most likely digit. We then call back to the success lambda which was passed in to classify().
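For completeness, calling the classifier from the UI layer might look something like this; fingerPaintView.getBitmap() and resultText are illustrative names for this sketch, not taken from the actual app:

// Hypothetical call site: failure falls back to the default (throwing) handler,
// so the trailing lambda is the success callback
val classifier = NumberClassifier()

classifier.classify(fingerPaintView.getBitmap()) { digit, confidence, elapsed ->
    resultText.text = "Detected $digit (${(confidence * 100).toInt()}%) in ${elapsed}ms"
}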

All in all this is a fairly compact class which is doing some really quite complex analysis thanks to our trained model; it is actually smaller than the custom view that we created to handle the user drawing digits. However, there was one area which caused me a lot of head scratching and took a while to get working.

I had got everything working, but the accuracy of the digit detection was pretty poor. Some digits, such as 0, 2, and 5, it was reasonably good at detecting, although it didn’t always get them right; others it almost always got wrong. I wondered if it was something related to how I was building the input Arrays: perhaps the second-level Array needed to hold columns rather than rows. I went through a whole host of tweaks and changes to the structure and could easily make the detection worse, but never managed to improve it.

After much trial and error and head scratching I discovered that I had indeed made an incorrect assumption about how I needed to provide the image for analysis, but it was nothing to do with the organisation of the pixels; it was to do with the sample images themselves. The bitmap being passed in for analysis was a black digit drawn on a white background, but the MNIST images used for training were of white digits drawn on a black background. I made a very small change to the convertToGreyScale() extension function to subtract the calculated grey value from 1.0, which performs an inversion as part of the greyscale conversion, and suddenly my detection accuracy improved enormously.
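In other words, the conversion needs to produce values where the drawn stroke is close to 1.0 and the background close to 0.0, matching the MNIST convention. The change was essentially this; the "before" version is reconstructed here for illustration:

// Before (reconstructed for illustration): maps a white background pixel to 1.0
// and a black stroke pixel to 0.0 – the opposite of the MNIST training images
private fun Int.convertToGreyScale(): Float =
        (Color.red(this) + Color.green(this) + Color.blue(this)).toFloat() / 3f / 255f

// After: subtracting from 1.0 inverts the value, so the drawn stroke becomes
// the high-intensity component, matching the white-on-black MNIST digits
private fun Int.convertToGreyScale(): Float =
        1f - ((Color.red(this) + Color.green(this) + Color.blue(this)).toFloat() / 3f / 255f)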

The accuracy is there and it’s also pretty fast. Once the interpreter is loaded and warmed up, it takes around 10-15ms to detect a digit on a Pixel XL. On top of this there is the image conversion, which takes about the same again, but 20-30ms is still quick. It’s not fast enough to run as part of an animation, but a user isn’t going to perceive any latency whatsoever. I used coroutines to run the conversion and interpreter on a background thread, and with this kind of performance we could probably get away with running it on the UI thread. That said, this work does not need to run on the UI thread, so in my opinion it is still better to put it on a background thread because something else may require the UI thread while this is running.
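As a rough illustration of pushing this off the main thread with coroutines, the call might look something like the following; the scope, dispatcher, and updateUi() function are assumptions for this sketch rather than the exact code from the app:

// Illustrative sketch only: uiScope and updateUi() are assumed names.
// The bitmap-to-array conversion happens synchronously inside classify(),
// so launching on a background dispatcher keeps that work off the main thread.
private val classifier = NumberClassifier()
private val uiScope = CoroutineScope(Dispatchers.Main)

fun classifyAsync(bitmap: Bitmap) {
    uiScope.launch(Dispatchers.Default) {
        classifier.classify(bitmap) { digit, confidence, elapsed ->
            updateUi(digit, confidence, elapsed)
        }
    }
}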

This was actually quite a steep learning curve for me: not so much in terms of the APIs and technologies that we’ve used to achieve this, but much more in terms of the importance of understanding the model and the data used to train it, so that the data we analyse matches it. I mentioned at the start of this series that I don’t believe many Android developers will be responsible for creating and training their own models, but it is important to talk to the data scientists who are responsible for that, and to understand the inputs and outputs, and how the data needs to be presented in order to get accurate analysis.

The source code for this article is available here.

© 2018, Mark Allison. All rights reserved.

