TensorFlow.js 在浏览器上也能搞定机器学习!

人工智能
TensorFlow 的平台就好像一个万花筒,让我看到了五彩缤纷的应用项目,同时也了解了机器学习建模和预测的底层逻辑。

在机器学习飞速发展的今天,各种机器学习平台层出不穷,为了满足不同业务场景的需求,可以将机器学习的模型分别部署到 Android、iOS、Web 浏览器,让模型在端侧能够进行推演,从而发挥模型的潜能。其中TensorFlow.js 是 TensorFlow 的 JavaScript 版本,支持 GPU 硬件加速,可以运行在 Node.js 或浏览器环境中。它不但支持完全基于 JavaScript 从头开发、训练和部署模型,也可以用来运行已有的 Python 版 TensorFlow 模型,或者基于现有的模型进行继续训练。

TensorFlow.js 优势

TensorFlow.js 可以让使用者在浏览器中加载 TensorFlow模型,让用户通过本地的 CPU/GPU 资源进行机器学习推演。浏览器中进行机器学习,相对比与服务器端来讲,将拥有以下四大优势:

1. 不需要安装软件或驱动(打开浏览器即可使用);

2. 可以通过浏览器进行更加方便的人机交互;

3. 可以通过手机浏览器,调用手机硬件的各种传感器(如:GPS、加速度传感器、摄像头等);

4. 用户的数据可以无需上传到服务器,在本地即可完成所需操作。

TensorFlow.js 架构

上面介绍了TensorFlow.js 的优势,这里让我们来了解一下TensorFlow.js 的架构。如图1 所示,TensorFlow.js 架构包括Core API 和 Layers API(图的上半部分)。其中Layers API 提供更高层次的接口,例如类似Keras API的语法结构,这些语法结构的目的是通过更加高粒度的抽象让开发人员使用JavaScript 便捷地进行机器学习的开发。而Core API主要包括TensorFlow.js 所提供的核心功能,例如Tensor的创建、数据的运算、内存管理等。同时Core API 还提供了工具将Python中的机器学习模型转换成浏览器能够使用的JSON格式,方便在JavaScript中能够复用已有的模型。因此,Core API能够在浏览器端运行,可以使用WebGL进行GPU加速,当然它也可以在Node.js 上运行,依赖具体的运行环境通过GPU、TPU进行加速。

图1  TensorFlow.js 架构

TensorFlow.js 进行线性回归的案例

前面说了TensorFlow.js 的优势和架构,这里为了大家能对TensorFlow.js 有更深的了解,我们举一个简单的线性回归例子来看看在浏览器端是如何实现机器学习的训练和推演的。

假设我们需要构建 y = ax1+bx2+c 的线性模型,如图2 所示,需要如下几个步骤完成:

1. 下载TensorFlow.js 文件

2. 训练数据和测试数据

3. 构建模型

4. 训练模型

5. 模型应用

图2 TensorFlow.js 构建线性回归模型

从这5 个步骤可以看出基本过程和在Python中构建模型是一样的,除了第一步需要下载TensorFlow.js 的文件以外。

如图3 所示,为了加载TensorFlow.js 文件,我们需要在页面的head 标签中引入script,其中文件tf.min.js 已经部署到了TensorFlow 的CDN 服务器了,我们只需要引用该文件即可。

图3 引用TensorFlow.js 文件

为了保证TensorFlow.js 文件被正确引入,如图4所示,打开浏览器并开启开发者工具,在Console 中输入tf.version 从而可以获取TensorFlow对应的tfjs-core,tfjs-backend-cpu 等信息,说明文件引入成功了。由于TensorFlow.js 文件中包含了TensorFlow的运算库,因此这里需要确保该文件被正确加载了。

图 4 确认TensorFlow.js 文件被正确引入

有了对TensorFlow.js 文件的加载之后,我们就可以在html中写入机器学习的代码了。 如图5 所示,在script标签中写入如下代码,其中async 的doTraining 方法是用来对模型进行训练的,epoch 是500 次,这里使用async 的目的是不让网页的其他操作阻塞。在函数内部调用了model 中的fit 方法对模型进行拟合,输入参数是xs 和ys,在回调函数callbacks 中输出拟合结果,并打印loss 的损失函数。

接下来就是来构造model了,这里使用了tf.sequential();构建模型,为了构建y = ax1+bx2+c 模型,这里需要构建一个神经元,这个神经元有两个输入和一个输出。

所以,通过model.add 添加一个dense 层,定义units:1,也就是一个神经元,inputShape:[2],输入是一个二维。有了模型之后,通过model.complie进行编译模型,这里使用了meanSquareError的损失函数以及optimizer为sgd。最后通过model的summary方法把整个神经元网络打印出来。紧接着在dataset环节,我们准备了xs 、ys作为输入,testData_x作为测试数据。最后,调用doTraining(model)对模型进行训练,并使用predict方法对结果进行预测。

图5 在浏览器中训练模型

将上述文件保存为html文件以后重新打开,大约1-2秒以后就可以看到图6 的结果。右边是开发者工具中打印出每次epoch 获得的loss 结果,可以看出随着训练的进程loss 损失函数是越来越小的。同时最终得到了Tensor的结果为 15.5082932 的预测结果。

图6 运行结果

TensorFlow.js 模型复用

有了上面简单的例子,我们可以在浏览器端轻松地巡检机器学习模型,但是模型的训练本事是需要耗费资源,同时也需要较长的训练时间。那么,我们能否将已经训练好的模型直接拿到浏览器进行预测和推演呢?答案是肯定的。

模型的复用一般而言有两种方式,第一种是使用开发者自己在Python中创建好的模型,通过TensroFlow提供的工具,将模型保存成tfjs格式并将其在浏览器中使用。另一种是直接调用TensorFlow 提供的模型。

图7 模型复用

开发自己定义的模型

如图8 所示,我们在python中进行模型构建、训练和保存。构建模型、神经元网络、设置优化器、损失函数以及数据准备等步骤,这里就不赘述。在模型训练完毕之后通过save_model 方法对模型进行保存。

图 8 开发自己的模型

有了模型,接着就需要使用TensorFlow.js 提供的工具对模型进行转换,才能让该模型在浏览器中被使用。

这里使用如下命令安装TensorFlow.js的工具。

pip install tensorflowjs
tensorflwjs_converter --input_format=keras_saved_model ./saved_model/ ./model/

这里使用了tensorflwjs_converter 命令对模型进行转换,input的格式是keras_saved_model,源文件地址是./saved_model/,目标文件地址是./model/,回车执行之后就可以在目标文件地址看到转换以后的文件了。

在浏览器中只需要引用这个转化好的模型文件,如图9 所示,在script中的run方法直接引用了模型文件model.json使用loadLayersModel装载模型,设置了input 之后就使用predict方法对模型进行预测了。

图9 使用转换后的模型

TensorFlow 提供的模型

上面我们演示了可以使用自己训练好的机器学习模型,这里也可以通过https://www.tensorflow.org/js/models 查找TensorFlow 提供的模型。

如图10 所示,TensorFlow 已经为一些业务场景量身打造了一些模型,例如:人像深度估测、图像分类、对象检测、身体分割、姿势检测、文本恶意检测等等。想了解如何进一步在生产场景中部署模型的同学,也可以抽空看看谷歌开发者专家对 TensorFlow 部署功能的讲解和常见问题的解答:https://zhibo.51cto.com/liveDetail/373

图 10 TensorFlow 提供的模型

通过学习TensorFlow 官方在线课程,我从一个机器学习小白成长为一个经验丰富的机器学习老手。从《TensorFlow 入门实操课程》《TensorFlow入门课程 — 部署篇》课程中,我学会了如何对机器学习模型进行保存转换,同时还可以根据不同的应用场景将机器学习模型部署到Android、iOS、浏览器以及服务端。TensorFlow 的平台就好像一个万花筒,让我看到了五彩缤纷的应用项目,同时也了解了机器学习建模和预测的底层逻辑。如果你也想让机器学习的能力有所提高,可以一起学习《TensorFlow入门课程 — 部署篇》,并留下你对课程的评价,现在报名参与,还有机会赢得官方精美礼品哦!

张云波,活跃的IT网红讲师,拥有学员31w+,国内早期开始和发布苹果Swift、安卓Kotlin、微信小程序、区块链技术的讲师之一。主攻前端开发、iOS开发、Android开发、Flutter开发、区块链Dapp开发,有丰富的大公司和海外工作经验。

责任编辑:张燕妮
相关推荐

2019-07-23 10:22:11

TensorFlow.Python机器学习

2018-09-10 14:38:16

编程语言TensorFlow.机器学习

2020-09-09 07:00:00

TensorFlow神经网络人工智能

2009-03-04 11:16:03

RABSoft浏览器控制电脑

2019-06-11 09:02:22

2017-03-03 16:50:01

2010-09-16 11:21:54

FirefoxJS

2018-06-26 15:40:49

Tensorflow.MNIST图像数据

2020-07-13 20:41:58

谷歌ChromeMac

2022-02-10 08:07:41

机器学习低代码开发

2012-05-16 09:04:53

WindowsPC浏览器

2011-06-24 10:06:13

浏览器

2019-07-24 15:25:29

框架AI开发

2022-06-12 11:12:37

GoogleChrome浏览器

2017-04-05 17:58:17

2012-03-09 09:11:29

Node.js

2012-09-03 15:27:43

搜狗浏览器

2016-02-02 10:03:15

chromeMaterial De

2020-07-17 07:21:36

TensorFlow机器学习计算机视觉

2020-09-15 08:26:25

浏览器缓存
点赞
收藏

51CTO技术栈公众号