
ML for Android Developers – Part 2

Machine Learning is very much in vogue at the moment, and there are many Android folks producing some excellent guides to creating and training models. However, I do not feel that, in the vast majority of cases, Android developers will be required to train their own models. But that does not mean that using a model someone else has created is a straightforward task. In this series we’ll take a look at some of the problems and pain points that I encountered while attempting to incorporate a pre-trained TensorFlow Lite model into an Android app.

Previously we examined our trained MNIST number recognition model to understand its inputs and outputs – an absolute requirement before we can implement that model within our app.

I mentioned in the previous article that one of my must-haves was to use MLKit rather than directly invoking the TensorFlow Lite Android runtime. Although MLKit uses the TensorFlow Lite runtime internally, it offers one huge benefit: you can store your TensorFlow Lite model in the cloud, and it will be automatically downloaded to the client.

Although it’s not a bad idea to bundle a version of the model in with your APK, it’s really not a good idea for that to be the only mechanism for distributing the model. The reason is that updating the model would then require a new APK and app update each time the model changes. That can be painful if you have a separate team responsible for updating the model and they wish to make new releases which are not in sync with the app development team’s planned release schedule.

It therefore makes more sense to periodically check a server and download a new model whenever it changes. This allows the model team to update the model independently of app releases. The only constraint is that, while they can refine the internals of the model and re-train it, they cannot change the input & output formats – changing those will need to be tied in with a new version of the app which works with the new input & output formats.

While it would not be too difficult to host the model file on a server ourselves, poll it periodically from the app, and download a new model whenever it changes, MLKit offers this functionality for free – which is the reason that I decided that using MLKit was an absolute requirement.

So let’s look at the code required to obtain the model:

class InterpreterProvider {

    private val options: FirebaseModelOptions

    init {
        // No conditions on the initial download – the app needs a model to be useful
        val initialConditions = FirebaseModelDownloadConditions.Builder()
                .build()
        // Only download updates when it won't inconvenience the user:
        // charging and idle (both API 24+), and on wifi
        val updateConditions = FirebaseModelDownloadConditions.Builder().run {
            if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.N) {
                requireCharging()
                requireDeviceIdle()
            }
            requireWifi()
            build()
        }
        // Register the cloud model named "mnist" with the model manager
        FirebaseCloudModelSource.Builder("mnist")
                .enableModelUpdates(true)
                .setInitialDownloadConditions(initialConditions)
                .setUpdatesDownloadConditions(updateConditions)
                .build().also { source ->
                    FirebaseModelManager.getInstance().registerCloudModelSource(source)
                }
        options = FirebaseModelOptions.Builder()
                .setCloudModelName("mnist")
                .build()
    }

    fun getInterpreter() = FirebaseModelInterpreter.getInstance(options)
}

The init block mainly sets the options for how we obtain the model. We create two sets of conditions which control the circumstances under which we’ll download the model. The initial conditions control how the model will be downloaded in the first instance, and have essentially no restrictions – we want the user to have the model as soon as possible. However, we are more conservative with the update conditions, which control how updated models will be downloaded. For these we require the device to be charging (so we don’t drain the battery), the device to be idle (so we don’t slow the device down while it is in use), and wifi (to avoid using the user’s mobile data).

If we were bundling a copy of the model in the APK then I would be a bit stricter with the initial conditions but, as the app is pretty useless without a model, I’ve taken this approach.
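As an aside, if we did bundle a fallback model in the APK, the same API allows a local model source to be registered alongside the cloud one. This is not part of the sample project – the asset path and local model name here are purely illustrative – but a rough sketch would look something like this:

// Hypothetical: register a model bundled as assets/mnist.tflite as a fallback.
// The sample project does not actually bundle a model.
val localSource = FirebaseLocalModelSource.Builder("mnist-local")
        .setAssetFilePath("mnist.tflite")
        .build()
FirebaseModelManager.getInstance().registerLocalModelSource(localSource)

val options = FirebaseModelOptions.Builder()
        .setCloudModelName("mnist")       // preferred once downloaded
        .setLocalModelName("mnist-local") // used until the cloud model is available
        .build()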

Once we have defined these, we build and register a FirebaseCloudModelSource. Finally we create the FirebaseModelOptions which we’ll use to obtain Interpreter instances.
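To give a flavour of where this is heading, here’s a rough sketch of how an interpreter obtained from getInterpreter() might be invoked. The input & output shapes are assumptions based on our earlier analysis of the model, and bitmapToMnistInput() is a hypothetical helper (sketched later in this article) – the real implementation is the NumberClassifier class which we’ll cover in the final article:

// A rough sketch only – not the real NumberClassifier implementation
fun classify(bitmap: Bitmap) {
    val interpreter = InterpreterProvider().getInterpreter()

    // Assumed shapes: a 1×28×28 float input and a 1×10 float output of
    // per-digit probabilities
    val ioOptions = FirebaseModelInputOutputOptions.Builder()
            .setInputFormat(0, FirebaseModelDataType.FLOAT32, intArrayOf(1, 28, 28))
            .setOutputFormat(0, FirebaseModelDataType.FLOAT32, intArrayOf(1, 10))
            .build()

    val inputs = FirebaseModelInputs.Builder()
            .add(bitmapToMnistInput(bitmap)) // hypothetical conversion helper
            .build()

    interpreter?.run(inputs, ioOptions)
            ?.addOnSuccessListener { outputs ->
                val probabilities = outputs.getOutput<Array<FloatArray>>(0)[0]
                val digit = probabilities.indices.maxBy { probabilities[it] }
                println("Detected digit: $digit")
            }
}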

The only other thing that we need to do is ensure that we have declared the INTERNET permission in our Manifest. With these relatively few lines of code our model will be downloaded from the cloud, and the app will periodically check for updated models and download them as necessary.
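That declaration is the standard one in AndroidManifest.xml:

<uses-permission android:name="android.permission.INTERNET" />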

All that remains is to upload the model to this app instance in the Firebase console. (We’ll need to connect the app to Firebase using the Firebase Assistant in Android Studio: just select any of the items in the Assistant and click on the link to enable it, then use the “Connect your Firebase App” option to connect your app, and ignore the rest.) Now that your app is connected, you’ll see it in the Firebase console, and you can upload a custom tflite model in the MLKit section.

I should mention that the project source code does not contain the google-services.json file that the Firebase Assistant added to my project. This is because it contains details specific to my Firebase account, so anyone wanting to use the sample code will need to use their own.

So with that in place we can now obtain the model from the cloud, and we can update the model as we get newer versions. Next we’ll turn our attention to the UI, which will be a canvas upon which the user can draw a digit, with a TextView to display the detected digit, and a button to clear the canvas.
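I won’t dwell on the layout itself, but a minimal sketch might look something like the following – the view IDs match those referenced from the Activity later on, while the package name of the custom view and everything else here is purely illustrative:

<!-- Hypothetical layout sketch – the real layout in the project may differ -->
<LinearLayout xmlns:android="http://schemas.android.com/apk/res/android"
    android:layout_width="match_parent"
    android:layout_height="match_parent"
    android:orientation="vertical">

    <com.stylingandroid.ml.FingerCanvasView
        android:id="@+id/finger_canvas"
        android:layout_width="match_parent"
        android:layout_height="0dp"
        android:layout_weight="1" />

    <TextView
        android:id="@+id/digit"
        android:layout_width="match_parent"
        android:layout_height="wrap_content"
        android:gravity="center"
        android:textSize="48sp" />

    <Button
        android:id="@+id/button_clear"
        android:layout_width="match_parent"
        android:layout_height="wrap_content"
        android:text="Clear" />

</LinearLayout>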

The drawing canvas is quite a simple custom view which handles touch events and constructs a Path object which gets drawn to the Canvas:

class FingerCanvasView @JvmOverloads constructor(
        context: Context,
        attrs: AttributeSet? = null,
        defStyleAttr: Int = 0
) : View(context, attrs, defStyleAttr) {

    private val path = Path()
    private val paint = Paint().apply {
        isAntiAlias = true
        isDither = true
        color = Color.BLACK
        style = Paint.Style.STROKE
        strokeJoin = Paint.Join.ROUND
        strokeCap = Paint.Cap.ROUND
        strokeWidth = 150f
    }

    private var lastX: Float = 0f
    private var lastY: Float = 0f

    private val tolerance = 4
    private val bitmapMinimumSize = 28f

    private var scaleFactor = 1f
    private lateinit var bitmap: Bitmap
    private lateinit var bitmapCanvas: Canvas

    override fun onDraw(canvas: Canvas) {
        canvas.drawPath(path, paint)
    }

    override fun onSizeChanged(w: Int, h: Int, oldw: Int, oldh: Int) {
        super.onSizeChanged(w, h, oldw, oldh)
        // Scale so that the smaller dimension of the backing bitmap is
        // 28 pixels – the input size the MNIST model was trained on
        scaleFactor = bitmapMinimumSize / Math.min(w, h).toFloat()

        bitmap = Bitmap.createBitmap(
                Math.round(w.toFloat() * scaleFactor),
                Math.round(h.toFloat() * scaleFactor),
                Bitmap.Config.RGB_565
        )
        bitmapCanvas = Canvas(bitmap).apply {
            scale(scaleFactor, scaleFactor)
        }
    }

    override fun onTouchEvent(event: MotionEvent): Boolean {
        when (event.action) {
            MotionEvent.ACTION_DOWN -> touchStart(event.x, event.y)
            MotionEvent.ACTION_MOVE -> touchMove(event.x, event.y)
            MotionEvent.ACTION_UP -> {
                touchEnd()
                performClick()
            }
        }
        return true
    }

    private fun touchStart(x: Float, y: Float) {
        path.moveTo(x, y)
        lastX = x
        lastY = y
        invalidate()
    }

    private fun touchMove(x: Float, y: Float) {
        val deltaX = Math.abs(x - lastX)
        val deltaY = Math.abs(y - lastY)
        if (deltaX >= tolerance || deltaY >= tolerance) {
            path.quadTo(lastX, lastY, (x + lastX) / 2f, (y + lastY) / 2f)
            lastX = x
            lastY = y
            invalidate()
        }
    }

    var drawingListener: (Bitmap) -> Unit = {}

    fun clear() {
        path.reset()
        invalidate()
    }

    private fun touchEnd() {
        path.lineTo(lastX, lastY)
    }

    override fun performClick(): Boolean {
        super.performClick()

        // Render the current path to the scaled-down bitmap and hand it
        // off to the listener for classification
        bitmapCanvas.drawColor(Color.WHITE)
        bitmapCanvas.drawPath(path, paint)
        drawingListener(bitmap)
        invalidate()
        return true
    }
}

I’m not going to give a full explanation of this as the focus of this series is implementing an ML model, but the key thing is that when the view receives an ACTION_UP event, because the user has lifted their finger from the screen, then as well as drawing to the screen we also draw the path to a bitmap and pass this to a drawing listener. The bitmap is actually a 28×28 pixel bitmap to match the training data for our MNIST model – it is important that the format of the data we analyse matches the format of the data the model was trained on. Being this small also makes the data much faster to analyse and categorise.
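To make that format requirement concrete, here’s a rough sketch of the kind of conversion that the bitmapToMnistInput() helper mentioned earlier might perform – turning the bitmap pixels into normalised float values. The exact shape and normalisation here are assumptions; we’ll see the real version when we look at NumberClassifier in the final article:

// A hypothetical sketch only – the real conversion lives in NumberClassifier
private fun bitmapToMnistInput(bitmap: Bitmap): Array<Array<FloatArray>> {
    val pixels = IntArray(bitmap.width * bitmap.height)
    bitmap.getPixels(pixels, 0, bitmap.width, 0, 0, bitmap.width, bitmap.height)
    return arrayOf(Array(bitmap.height) { y ->
        FloatArray(bitmap.width) { x ->
            val pixel = pixels[y * bitmap.width + x]
            // Average the RGB channels to greyscale, then invert so that the
            // ink is high-valued, matching the MNIST training images
            val grey = (Color.red(pixel) + Color.green(pixel) + Color.blue(pixel)) / 3f
            1f - (grey / 255f)
        }
    })
}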

We wire up the various controls in the Activity:

class MainActivity : AppCompatActivity() {

    private lateinit var numberClassifier: NumberClassifier

    override fun onCreate(savedInstanceState: Bundle?) {
        super.onCreate(savedInstanceState)
        setContentView(R.layout.activity_main)

        numberClassifier = NumberClassifier()

        finger_canvas.drawingListener = { bitmap ->
            // Run the classification off the main thread
            launch(CommonPool) {
                val start = System.currentTimeMillis()
                numberClassifier.classify(bitmap) { result, confidence, elapsed ->
                    val total = System.currentTimeMillis() - start
                    println("Result: $result, confidence: $confidence, elapsed: ${total}ms total, ${elapsed}ms in ML")
                    // Hop back to the main thread to update the UI
                    launch(UI) {
                        digit.text = result.toString()
                    }
                }
            }
        }

        button_clear.setOnClickListener {
            finger_canvas.clear()
            digit.text = ""
        }
    }
}

Here we have the drawingListener implementation which passes the bitmap to a NumberClassifier instance – the ML model implementation itself. In the final article in this series we’ll take a look at that class and cover a real pain point which I experienced in trying to get it working.

Although we don’t have a fully working digit detection app yet, the source code we have so far is available here.

© 2018, Mark Allison. All rights reserved.
