在 Flutter 中使用 TensorFlow Lite 插件实现文字分类

如果您希望能有一种简单、高效且灵活的方式把 TensorFlow 模型集成到 Flutter 应用里,那请您一定不要错过我们今天介绍的这个全新插件 tflite_flutter。这个插件的开发者是 Google Summer of Code(GSoC) 的一名实习生 Amish Garg,本文来自他在 Medium 上的一篇文章《在 Flutter 中使用 TensorFlow Lite 插件实现文字分类》。

tflite_flutter 插件的核心特性:

  • 它提供了与 TFLite Java 和 Swift API 相似的 Dart API,所以其灵活性和在这些平台上的效果是完全一样的

  • 通过 dart:ffi 直接与 TensorFlow Lite C API 相绑定,所以它比其它平台集成方式更加高效。

  • 无需编写特定平台的代码。

  • 通过 NNAPI 提供加速支持,在 Android 上使用 GPU Delegate,在 iOS 上使用 Metal Delegate。

本文中,我们将使用 tflite_flutter 构建一个 文字分类 Flutter 应用 带您体验 tflite_flutter 插件,首先从新建一个 Flutter 项目 text_classification_app 开始。

初始化配置

Linux 和 Mac 用户

将 install.sh 拷贝到您应用的根目录,然后在根目录执行 sh install.sh,本例中就是目录 text_classification_app/

Windows 用户

将 install.bat 文件拷贝到应用根目录,并在根目录运行批处理文件 install.bat,本例中就是目录 text_classification_app/。 

它会自动从 release assets 下载最新的二进制资源,然后把它放到指定的目录下。

请点击到 README 文件里查看更多 关于初始配置的信息。

获取插件

在 pubspec.yaml 添加 tflite_flutter: ^<latest_version> (详情)。

下载模型

要在移动端上运行 TensorFlow 训练模型,我们需要使用 .tflite 格式。如果需要了解如何将 TensorFlow 训练的模型转换为 .tflite 格式,请参阅官方指南。 

这里我们准备使用 TensorFlow 官方站点上预训练的文字分类模型,可从这里下载。

该预训练的模型可以预测当前段落的情感是积极还是消极。它是基于来自 Mass 等人的  Large Movie Review Dataset v1.0 数据集进行训练的。数据集由基于 IMDB 电影评论所标记的积极或消极标签组成,点击查看更多信息。

将 text_classification.tflite 和 text_classification_vocab.txt 文件拷贝到 text_classification_app/assets/ 目录下。

在 pubspec.yaml 文件中添加 assets/

assets:    
  - assets/1

现在万事俱备,我们可以开始写代码了。 🚀

实现分类器

预处理

正如 文字分类模型页面 里所提到的。可以按照下面的步骤使用模型对段落进行分类:

  1. 对段落文本进行分词,然后使用预定义的词汇集将它转换为一组词汇 ID;

  2. 将生成的这组词汇 ID 输入 TensorFlow Lite 模型里;

  3. 从模型的输出里获取当前段落是积极或者是消极的概率值。

我们首先写一个方法对原始字符串进行分词,其中使用 text_classification_vocab.txt 作为词汇集。

在 lib/ 文件夹下创建一个新文件 classifier.dart。 

这里先写代码加载 text_classification_vocab.txt 到字典里。

import 'package:flutter/services.dart';

class Classifier {
  final _vocabFile = 'text_classification_vocab.txt';

  Map<String, int> _dict;

  Classifier() {
    _loadDictionary();
  }

  void _loadDictionary() async {
    final vocab = await rootBundle.loadString('assets/$_vocabFile');
    var dict = <String, int>{};
    final vocabList = vocab.split('\n');
    for (var i = 0; i < vocabList.length; i++) {
      var entry = vocabList[i].trim().split(' ');
      dict[entry[0]] = int.parse(entry[1]);
    }
    _dict = dict;
    print('Dictionary loaded successfully');
  }

}1234567891011121314151617181920212223

加载字典

现在我们来编写一个函数对原始字符串进行分词。

import 'package:flutter/services.dart';

class Classifier {
  final _vocabFile = 'text_classification_vocab.txt';

  // 单句的最大长度
  final int _sentenceLen = 256;

  final String start = '<START>';
  final String pad = '<PAD>';
  final String unk = '<UNKNOWN>';

  Map<String, int> _dict;

  List<List<double>> tokenizeInputText(String text) {

    // 使用空格进行分词
    final toks = text.split(' ');

    // 创建一个列表,它的长度等于 _sentenceLen,并且使用 <pad> 的对应的字典值来填充
    var vec = List<double>.filled(_sentenceLen, _dict[pad].toDouble());

    var index = 0;
    if (_dict.containsKey(start)) {
      vec[index++] = _dict[start].toDouble();
    }

    // 对于句子里的每个单词在 dict 里找到相应的 index 值
    for (var tok in toks) {
      if (index > _sentenceLen) {
        break;
      }
      vec[index++] = _dict.containsKey(tok)
          ? _dict[tok].toDouble()
          : _dict[unk].toDouble();
    }

    // 按照我们的解释器输入 tensor 所需的形状 [1,256] 返回 List<List<double>>
    return [vec];
  }
}123456789101112131415161718192021222324252627282930313233343536373839404142

使用 tflite_flutter 进行分析

这是本文的主体部分,这里我们会讨论 tflite_flutter 插件的用途。

这里的分析是指基于输入数据在设备上使用 TensorFlow Lite 模型的处理过程。要使用 TensorFlow Lite 模型进行分析,需要通过 解释器 来运行它,了解更多。

创建解释器,加载模型

tflite_flutter 提供了一个方法直接通过资源创建解释器。

static Future<Interpreter> fromAsset(String assetName, {InterpreterOptions options})

由于我们的模型在 assets/ 文件夹下,需要使用上面的方法来创建解析器。对于 InterpreterOptions 的相关说明,请 参考这里。

import 'package:flutter/services.dart';

// 引入 tflite_flutter
import 'package:tflite_flutter/tflite_flutter.dart';

class Classifier {
  // 模型文件的名称
  final _modelFile = 'text_classification.tflite';

  // TensorFlow Lite 解释器对象
  Interpreter _interpreter;

  Classifier() {
    // 当分类器初始化以后加载模型
    _loadModel();
  }

  void _loadModel() async {

    // 使用 Interpreter.fromAsset 创建解释器
    _interpreter = await Interpreter.fromAsset(_modelFile);
    print('Interpreter loaded successfully');
  }

}12345678910111213141516171819202122232425

创建解释器的代码

如果您不希望将模型放在 assets/ 目录下,tflite_flutter 还提供了工厂构造函数创建解释器,更多信息。

我们开始进行分析!

现在用下面方法启动分析:

void run(Object input, Object output);

注意这里的方法和 Java API 中的是一样的。

Object input 和 Object output 必须是和 Input Tensor 与 Output Tensor 维度相同的列表。

要查看  input tensors 和 output tensors 的维度,可以使用如下代码:

_interpreter.allocateTensors();
// 打印 input tensor 列表
print(_interpreter.getInputTensors());
// 打印 output tensor 列表
print(_interpreter.getOutputTensors());1234

在本例中 text_classification 模型的输出如下: 

InputTensorList:
[Tensor{_tensor: Pointer<TfLiteTensor>: address=0xbffcf280, name: embedding_input, type: TfLiteType.float32, shape: [1, 256], data:  1024]
OutputTensorList:
[Tensor{_tensor: Pointer<TfLiteTensor>: address=0xbffcf140, name: dense_1/Softmax, type: TfLiteType.float32, shape: [1, 2], data:  8]123

现在,我们实现分类方法,该方法返回值为 1 表示积极,返回值为 0 表示消极。

int classify(String rawText) {

    //  tokenizeInputText 返回形状为 [1, 256] 的 List<List<double>>
    List<List<double>> input = tokenizeInputText(rawText);

    // [1,2] 形状的输出
    var output = List<double>(2).reshape([1, 2]);

    // run 方法会运行分析并且存储输出的值
    _interpreter.run(input, output);

    var result = 0;
    // 如果输出中第一个元素的值比第二个大,那么句子就是消极的

    if ((output[0][0] as double) > (output[0][1] as double)) {
      result = 0;
    } else {
      result = 1;
    }
    return result;
  }123456789101112131415161718192021

用于分析的代码

在 tflite_flutter 的 extension ListShape on List 下面定义了一些使用的扩展:

// 将提供的列表进行矩阵变形,输入参数为元素总数 // 保持相等 
// 用法:List(400).reshape([2,10,20]) 
// 返回  List<dynamic>

List reshape(List<int> shape)
// 返回列表的形状
List<int> get shape
// 返回列表任意形状的元素数量
int get computeNumElements123456789

最终的 classifier.dart 应该是这样的:

import 'package:flutter/services.dart';

// 引入 tflite_flutter
import 'package:tflite_flutter/tflite_flutter.dart';

class Classifier {
  // 模型文件的名称
  final _modelFile = 'text_classification.tflite';
  final _vocabFile = 'text_classification_vocab.txt';

  // 语句的最大长度
  final int _sentenceLen = 256;

  final String start = '<START>';
  final String pad = '<PAD>';
  final String unk = '<UNKNOWN>';

  Map<String, int> _dict;

  // TensorFlow Lite 解释器对象
  Interpreter _interpreter;

  Classifier() {
    // 当分类器初始化的时候加载模型
    _loadModel();
    _loadDictionary();
  }

  void _loadModel() async {
    // 使用 Intepreter.fromAsset 创建解析器
    _interpreter = await Interpreter.fromAsset(_modelFile);
    print('Interpreter loaded successfully');
  }

  void _loadDictionary() async {
    final vocab = await rootBundle.loadString('assets/$_vocabFile');
    var dict = <String, int>{};
    final vocabList = vocab.split('\n');
    for (var i = 0; i < vocabList.length; i++) {
      var entry = vocabList[i].trim().split(' ');
      dict[entry[0]] = int.parse(entry[1]);
    }
    _dict = dict;
    print('Dictionary loaded successfully');
  }

  int classify(String rawText) {
    // tokenizeInputText  返回形状为 [1, 256] 的 List<List<double>>
    List<List<double>> input = tokenizeInputText(rawText);

    //输出形状为 [1, 2] 的矩阵
    var output = List<double>(2).reshape([1, 2]);

    // run 方法会运行分析并且将结果存储在 output 中。
    _interpreter.run(input, output);

    var result = 0;
    // 如果第一个元素的输出比第二个大,那么当前语句是消极的

    if ((output[0][0] as double) > (output[0][1] as double)) {
      result = 0;
    } else {
      result = 1;
    }
    return result;
  }

  List<List<double>> tokenizeInputText(String text) {
    // 用空格分词
    final toks = text.split(' ');

    // 创建一个列表,它的长度等于 _sentenceLen,并且使用 <pad> 对应的字典值来填充
    var vec = List<double>.filled(_sentenceLen, _dict[pad].toDouble());

    var index = 0;
    if (_dict.containsKey(start)) {
      vec[index++] = _dict[start].toDouble();
    }

    // 对于句子中的每个单词,在 dict 中找到相应的 index 值
    for (var tok in toks) {
      if (index > _sentenceLen) {
        break;
      }
      vec[index++] = _dict.containsKey(tok)
          ? _dict[tok].toDouble()
          : _dict[unk].toDouble();
    }

    // 按照我们的解释器输入 tensor 所需的形状 [1,256] 返回 List<List<double>>
    return [vec];
  }
}123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293

现在,可以根据您的喜好实现 UI 的代码,分类器的用法比较简单。

// 创建 Classifier 对象
Classifer _classifier = Classifier();
// 将目标语句作为参数,调用 classify 方法
_classifier.classify("I liked the movie");
// 返回 1 (积极的)
_classifier.classify("I didn't liked the movie");
// 返回 0 (消极的)123456

请在这里查阅完整代码:Text Classification Example app with UI。

Text Classification Example App

文字分类示例应用

了解更多关于 tflite_flutter 插件的信息,请访问 GitHub repo: am15h/tflite_flutter_plugin

答疑

问:tflite_flutter 和 tflite v1.0.5 有哪些区别?

tflite v1.0.5 侧重于为特定用途的应用场景提供高级特性,比如图片分类、物体检测等等。而新的 tflite_flutter 则提供了与 Java API 相同的特性和灵活性,而且可以用于任何 tflite 模型中,它还支持 delegate。

由于使用 dart:ffi (dart ↔️ (ffi) ↔️ C),tflite_flutter 非常快 (拥有低延时)。而 tflite 使用平台集成 (dart ↔️ platform-channel ↔️ (Java/Swift) ↔️ JNI ↔️ C)。

问:如何使用 tflite_flutter 创建图片分类应用?有没有类似 TensorFlow Lite Android Support Library 的依赖包?

更新(07/01/2020): TFLite Flutter Helper 开发库已发布。

TensorFlow Lite Flutter Helper Library 为处理和控制输入及输出的 TFLite 模型提供了易用的架构。它的 API 设计和文档与 TensorFlow Lite Android Support Library 是一样的。更多信息请 参考这里。

以上是本文的全部内容,欢迎大家对 tflite_flutter 插件进行反馈,请在这里 上报 bug 或提出功能需求。

谢谢关注。

感谢 Michael Thomsen。

致谢

  • 译者:Yuan,谷创字幕组

  • 审校:Xinlei、Lynn Wang、Alex,CFUG 社区。


热门文章

编程学习 ·

yum本地云搭建

yum Yum(全称为 Yellow dog Updater, Modified)是一个在Fedora和RedHat以及CentOS中的Shell前端软件包管理器。基于RPM包管理,能够从指定的服务器自动下载RPM包并且安装,可以自动处理依赖性关系,并且一次安装所有依赖的软件包,无须繁琐地一次次下载、安装。 yum仓库配置文…
编程学习 ·

JVM——Java的内存回收

Java引用的种类对于JVM的垃圾回收机制来说,如果一个对象,没有一个引用指向它,那么它就被认为是一个垃圾。那该对象就会回收。可以把JVM内存中对象引用理解成一种有向图,把引用变量、对象都当成有向图的顶点,将引用关系当成图的有向边(注意:有向边总是从引用变量指向被引…
编程学习 ·

两种判断对象类型的方法

两种判断对象类型的方法: 1.通过instanceof *缺点:不能准确的判断该对象是Dog的实例,如果该对象是类的子类对象也会返回true 2.对象.getClass().getName()获取对象的实例类名 (1)对象.getClass():返回该对象对应的Class对象 (2)对象.getClass().getName():该对象对应的class对…
编程学习 ·

IT系统稳定性创新者:分布式软件,“笨马”先跑

(PerfMa CEO 李嘉鹏)早在2006年前后,IT系统稳定性就成为了当时集中式架构的挑战。随着互联网的快速兴起,当时的“Unix+小型机”架构遭遇了数据爆增的冲击。特别是在线交易、商业分析和数据库等关键业务系统,在2010年前后进入了TB甚至PB级,导致传统IT架构不堪重负,对IT系统…
编程学习 ·

CMDB可用于那些服务和流程

CMDB可用于那些服务和流程CMDB不应该有哪些功能工单流程管理工单流程管理是一种流程管理手段,通过提交工单,逐级审批的方式,实现流程的流转,并可以提供回调Hook来自动执行某些操作。这样一个工单流程管理的功能,不仅需要对工单流程有详尽的了解,还需要对每个流程进行定制…
编程学习 ·

Java小型计算器

通过对程序的编写,可以不同位数实现对加减乘除的计算 ,以及对错误答案给出提示。以满足一些大人给小孩出题的困惑 ,此程序可以自己出题,自己检测答案。随时随地想做就做。需求:1.实现计算器的基本功能。 2.可以练习加减乘除的计算,以提高自己算题的速度 和探索新的解题方…
编程学习 ·

Centos7x破解密码

办法一 1、开机启动部分1)开机e 选择第一行 e 2)找到Linux16所在行 ***.UTF8后面添加 rd.break console=tty0 3)ctrl + X2、启动到内核部分1)挂载/sysroot目录 #mount -o remount,rw /sysroot 2)切换到/sysroot目录 #chroot /sysroot 3)修改root密码 #echo passwo…
编程学习 ·

冒泡、选择、插入排序算法(c语言)实现

几种常见排序算法的实现 一、冒泡排序 1.百度百科 冒泡排序(Bubble Sort),是一种计算机科学领域的较简单的排序算法。它重复地走访过要排序的元素列,依次比较两个相邻的元素,如果顺序(如从大到小、首字母从Z到A)错误就把他们交换过来。走访元素的工作是重复地进行直到没…
编程学习 ·

啥是智慧社区-百度人脸识别

还记得前几年大家常说的看“脸”的时代吗?现如今回家必须得看脸了,人脸识别助力智慧社区管理升级,以前我们只能在电影里看到了,刷脸进出小区,刷脸开锁等在现实中已经实现了,那使用了人脸识别的智慧社区到底是个啥?下面AI人工智能带大家一探究竟。1、人工智能赋予美好生活…
编程学习 ·

Vue——09——v-for和key指令

遍历普通数组 <!DOCTYPE html> <html lang="en"><head><meta charset="UTF-8"><meta name="viewport" content="width=device-width, initial-scale=1.0"><title>Document</title><scri…
编程学习 ·

elementUI From表单踩坑之watch 变量监控

-当修改input框内的值(form.name)的时候,watch 监控from失败,watch中的from不相应,打印无效;<el-form ref="form" :model="form" label-width="80px"><el-form-item label="活动名称"><el-input v-model="f…
编程学习 ·

一周信创舆情观察(6.22~6.28)

一、一周舆情要点 第四届世界智能大会本周成功举办,技术服务项目由腾讯云提供支持。大会云签约148个项目,其中内资项目投资809亿元,外资项目投资约16亿美元。会议期间,天津港集团和华为签署战略合作协议,双方将加强信息化顶层设计及智慧港口合作。 数据安全监管趋严,网安…
编程学习 ·

windows10设置jdk环境变量

先安装好jdk 电脑–>属性–>高级系统设置–>新建 点击Path,编辑 添加这两项,确定 win+R打开cmd java -version检查jdk环境变量是否配置成功
编程学习 ·

kuangbin专题8 生成树 次小生成树部分 HDU4081/UVA10600/UVA10462

前言 本来壮志凌云的想都做完 发现我在做梦。。。 朱刘算法太难了(自己太懒发现性价比比较低之后就没做而且算法介绍也太难懂了好几个关键词含义都不给简直简直太难了我枯 HDU4081 Qin Shi Huang’s National Road System 题意:给你一个图的各个点的坐标 再给你每个点的权值…
编程学习 ·

JetPack 之 Paging3.0 简单上手指南!

作者:Chsmy之前有一篇Paging2.x的使用和分析,Paging2.x运行起来的效果无限滑动还挺不错的,不过代码写起来有点麻烦,功能也不是太完善,比如下拉刷新的方法都没有提供,我们还得自己去调用DataSource#invalidate()方法重置数据来实现。最近google出了3.0的测试版,功能更加强…
编程学习 ·

前端性能优化

浏览器渲染机制 Html解析成DOM树,Css解析成CSS树,将DOM树与CSSDOM规则树合并在一起生成Render树,遍历渲染树开始布局,计算每个节点的位置大小信息,将渲染树每个节点绘制到屏幕阻塞渲染当浏览器遇到一个script标记时,DOM构建将暂停,直至脚本完成执行,然后继续构建DOM。每…
编程学习 ·

创新实训—动画小插件开发实践

基于现今动画行业的发展越来越快,为了有效提供动画制作人员的工作效率,许多动画制作软件诸如maya、3d max以及blender越来越注意软件的高效化,无数的插件慢慢地被开发。我们小组使用Python去开发相应的插件,以加快动画制作人员的制作效率1、首先Pyqt的搭建这也是我第一次使…
编程学习 ·

UE4学习-添加机关并添加代码控制

文章目录添加机关代码编写给密室添加屋顶打印日志控制系统角色创建一个新游戏模式替换DefaultPawn添加抓取组件获取起点和终点物体拾取,碰撞属性设置今日完整代码 添加机关 首先向场景里面添加一个聚光源添加聚光源以后,可以对其属性进行修改,如图:然后需要给聚光源添加一个…
编程学习 ·

AJAX

原生AJAX ajax概念:在不进行整个页面的更新的情况下,局部更新界面。 局部刷新技术 ajax 和请求数据有关 它的出现开始前后端分离 ajax出现之前 开发人员前端和后端都做 前后端分工 中间由ajax来对接 ajax就是异步的javascript和xml(树形结构文档 xhml—写法和html写法一…
编程学习 ·

程序人生 - 西瓜霜能吃下去吗?

西瓜霜是可以吃下去的,但是会影响到临床的药效,临床常用的有西瓜霜和西瓜霜润喉片主要药理作用就是消肿止痛,清咽利嗓,多用于急慢性咽喉炎,扁桃体炎,口腔溃疡,口舌生疮等引起的咽喉疼痛,声音嘶哑,牙龈红肿。通常,主要是局部用药,这样才能够更好的发挥疗效,成年人一…