CreateML 使用以及在 iOS 中应用介绍

移动开发 iOS
Create ML 是苹果于2018年 WWDC 推出的生成机器学习模型的工具。它可以接收用户给定的数据,生成 iOS 开发中需要的机器学习模型(Core ML 模型)。

aPaaS Growth 团队专注在用户可感知的、宏观的 aPaaS 应用的搭建流程,及租户、应用治理等产品路径,致力于打造 aPaaS 平台流畅的 “应用交付” 流程和体验,完善应用构建相关的生态,加强应用搭建的便捷性和可靠性,提升应用的整体性能,从而助力 aPaaS 的用户增长,与基础团队一起推进 aPaaS 在企业内外部的落地与提效。

在低代码/无代码领域,例如 MS Power Platform,AWS 的 Amplify 都有类似于 AI Builder 的产品,这些产品主要让用户很低门槛训练自己的深度学习模型,CreateML 是苹果生态下的产品,工具上伴随 XCode 下发,安装了 XCode 的同学也可以打开来体验一下(得自己准备数据集)。

什么是 CreateML

图片

Create ML 是苹果于2018年 WWDC 推出的生成机器学习模型的工具。它可以接收用户给定的数据,生成 iOS 开发中需要的机器学习模型(Core ML 模型)。

iOS 开发中,机器学习模型的获取主要有以下几种:

  • 从苹果的官方主页[1]下载现成的模型。2017年有4个现成的模型,2018年有6个,2019年增加到了9个(8个图片、1个文字),今年进展到了 13,数量有限,进步速度缓慢,但是这些模型都是比较实用的,能在手机上在用户体验允许的情况下能够跑起来的。
  • 用第三方的机器学习框架生成模型,再用 Core ML Tools 转成 Core ML 模型。2017年苹果宣布支持的框架有6个,包括 Caffee、Keras。2018年宣布支持的第三方框架增加到了11个,包括了最知名的 TensorFlow、IBM Watson、MXNet。至此 Core ML 已经完全支持市面上所有主流的框架。
  • 用 Create ML 直接训练数据生成模型。2018年推出的初代 Create ML有三个特性:使用 Swift 编程进行操作、用 Playground 训练和生成模型、在 Mac OS 上完成所有工作。

今年的 Create ML 在易用性上更进一步:无需编程即可完成操作、独立成单独的 Mac OS App、支持更多的数据类型和使用场景。

CreateML 模型列表

图片

1、Image Classification:图片分类

图片

2、Object Detection:

图片

3、Style Transfer

图片

4、Hand Pose & Hand Action

图片

5、Action Classification

图片

6、Activity Classification

图片

图片

图片

7、Sound Classification

想象一下「Hey Siri」实现

图片

8、Text Classification

图片

9、Word Tagging

图片

10、Tabular Classification & Regression

图片

通过若干个维度,预测另外一个维度,例如通过性别、年龄、城市等推断你的收入级别。

11、Recommendation

例如你买了啤酒,推荐你买花生。历史上的也有一些不是基于深度学习的算法,例如 Apriori 等。

CreateML 模型尝鲜

图片

训练一个目标检测的 CreateML 模型

数据准备

有些同学可能认为觉得训练深度模型的难点在于找到适当的算法/模型、在足够强的机器下训练足够多的迭代次数。但是事实上,对于深度模型来说,最最最关键的是具有足够多的、精确的数据源,这也是 AI 行业容易形成头部效应最主要原因。假设你在做一个 AI 相关的应用,最主要需要关注的是如何拥有足够多的、精确的数据源。

下面我就与上面「尝鲜」的模型为例,讲述如何训练类似模型的。

数据格式

CreateML 目标检测的数据格式如下图:

图片

首先会有一个叫 annotions.json 的文件,这个文件会标注每个文件里有多少个目标,以及目标的 Bounding Box 的坐标是什么。

图片

例如上图对应的 Bounding Box 如下:

图片

准备足够多的数据

第一个问题是,什么才叫足够多的数据,我们可以看一些 Dataset 来参考一下:

Standford Cars Dataset: 934MB. The Cars dataset contains 16,185 images of 196 classes of cars. The data is split into 8,144 training images and 8,041 testing images。

https://www.kaggle.com/datasets/kmader/food41: Labeled food images in 101 categories from apple pies to waffles, 6GB

在上面这个例子里,原神的角色有大概 40 多个,所以我们需要准备大概百来 MB 的数据来训练作为起来,当精确度不高的时候,再增加样本的数量来增加精度。问题是我们去哪里找那么多数据呢?所以我想到的一个方法是通过脚本来合成,因为我们的问题只是定位提取图片中的角色「证件照」,我用大概 40 来角色的证件照,写了如下的脚本(colipot helped a alot ...)来生成大概 500MB 的测试训练集:

// import sharp from "sharp";

import { createCanvas, Image } from "@napi-rs/canvas";
import { promises } from "fs";
import fs from "fs";
import path from "path";
import Sharp from "sharp";

const IMAGE_GENERATED_COUNT_PER_CLASS = 5;
const MAX_NUMBER_OF_CLASSES_IN_SINGLE_IMAGE = 10;
const CANVAS_WIDTH = 1024;
const CANVAS_HEIGHT = 800;
const CONCURRENT_PROMISE_SIZE = 50;

const CanvasSize = [CANVAS_WIDTH, CANVAS_HEIGHT];

function isNotOverlap(x1: number, y1: number, width1: number, height1: number, x2: number, y2: number, width2: number, height2: number) {
return x1 >= x2 + width2 || x1 + width1 <= x2 || y1 >= y2 + height2 || y1 + height1 <= y2;
}

const randomColorList: Record<string, string> = {
"white": "rgb(255, 255, 255)",
"black": "rgb(0, 0, 0)",
"red": "rgb(255, 0, 0)",
"green": "rgb(0, 255, 0)",
"blue": "rgb(0, 0, 255)",
"yellow": "rgb(255, 255, 0)",
"cyan": "rgb(0, 255, 255)",
"magenta": "rgb(255, 0, 255)",
"gray": "rgb(128, 128, 128)",
"grey": "rgb(128, 128, 128)",
"maroon": "rgb(128, 0, 0)",
"olive": "rgb(128, 128, 0)",
"purple": "rgb(128, 0, 128)",
"teal": "rgb(0, 128, 128)",
"navy": "rgb(0, 0, 128)",
"orange": "rgb(255, 165, 0)",
"aliceblue": "rgb(240, 248, 255)",
"antiquewhite": "rgb(250, 235, 215)",
"aquamarine": "rgb(127, 255, 212)",
"azure": "rgb(240, 255, 255)",
"beige": "rgb(245, 245, 220)",
"bisque": "rgb(255, 228, 196)",
"blanchedalmond": "rgb(255, 235, 205)",
"blueviolet": "rgb(138, 43, 226)",
"brown": "rgb(165, 42, 42)",
"burlywood": "rgb(222, 184, 135)",
"cadetblue": "rgb(95, 158, 160)",
"chartreuse": "rgb(127, 255, 0)",
"chocolate": "rgb(210, 105, 30)",
"coral": "rgb(255, 127, 80)",
"cornflowerblue": "rgb(100, 149, 237)",
"cornsilk": "rgb(255, 248, 220)",
"crimson": "rgb(220, 20, 60)",
"darkblue": "rgb(0, 0, 139)",
"darkcyan": "rgb(0, 139, 139)",
"darkgoldenrod": "rgb(184, 134, 11)",
"darkgray": "rgb(169, 169, 169)",
"darkgreen": "rgb(0, 100, 0)",
"darkgrey": "rgb(169, 169, 169)",
"darkkhaki": "rgb(189, 183, 107)",
"darkmagenta": "rgb(139, 0, 139)",
"darkolivegreen": "rgb(85, 107, 47)",
"darkorange": "rgb(255, 140, 0)",
"darkorchid": "rgb(153, 50, 204)",
"darkred": "rgb(139, 0, 0)"
}

function generateColor(index: number = -1) {
if (index < 0 || index > Object.keys(randomColorList).length) {
// return random color from list
let keys = Object.keys(randomColorList);
let randomKey = keys[Math.floor(Math.random() * keys.length)];
return randomColorList[randomKey];
} else {
// return color by index
let keys = Object.keys(randomColorList);
return randomColorList[keys[index]];
}
}

function randomPlaceImagesInCanvas(canvasWidth: number, canvasHeight: number, images: number[][], overlapping: boolean = true) {
let placedImages: number[][] = [];
for (let image of images) {
let [width, height] = image;
let [x, y] = [Math.floor(Math.random() * (canvasWidth - width)), Math.floor(Math.random() * (canvasHeight - height))];
let placed = false;
for (let placedImage of placedImages) {
let [placedImageX, placedImageY, placedImageWidth, placedImageHeight] = placedImage;
if (overlapping || isNotOverlap(x, y, width, height, placedImageX, placedImageY, placedImageWidth, placedImageHeight)) {
placed = true;
}
}
placedImages.push([x, y, placed ? 1 : 0]);
}
return placedImages;
}

function getSizeBasedOnRatio(width: number, height: number, ratio: number) {
return [width * ratio, height];
}

function cartesianProductOfArray(...arrays: any[][]) {
return arrays.reduce((a, b) => a.flatMap((d: any) => b.map((e: any) => [d, e].flat())));
}

function rotateRectangleAndGetSize(width: number, height: number, angle: number) {
let radians = angle * Math.PI / 180;
let cos = Math.abs(Math.cos(radians));
let sin = Math.abs(Math.sin(radians));
let newWidth = Math.ceil(width * cos + height * sin);
let newHeight = Math.ceil(height * cos + width * sin);
return [newWidth, newHeight];
}

function concurrentlyExecutePromisesWithSize(promises: Promise<any>[], size: number): Promise<void> {
let promisesToExecute = promises.slice(0, size);
let promisesToWait = promises.slice(size);
return Promise.all(promisesToExecute).then(() => {
if (promisesToWait.length > 0) {
return concurrentlyExecutePromisesWithSize(promisesToWait, size);
}
});
}

function generateRandomRgbColor() {
return [Math.floor(Math.random() * 256), Math.floor(Math.random() * 256), Math.floor(Math.random() * 256)];
}

function getSizeOfImage(image: Image) {
return [image.width, image.height];
}

async function makeSureFolderExists(path: string) {
if (!fs.existsSync(path)) {
await promises.mkdir(path, { recursive: true });
}
}

// non repeatly select elements from array
async function randomSelectFromArray<T>(array: T[], count: number) {
let copied = array.slice();
let selected: T[] = [];
for (let i = 0; i < count; i++) {
let index = Math.floor(Math.random() * copied.length);
selected.push(copied[index]);
copied.splice(index, 1);
}
return selected;
}

function getFileNameFromPathWithoutPrefix(path: string) {
return path.split("/").pop()!.split(".")[0];
}

type Annotion = {
"image": string,
"annotions": {
"label": string,
"coordinates": {
"x": number,
"y": number,
"width": number,
"height": number
}
}[]
}

async function generateCreateMLFormatOutput(folderPath: string, outputDir: string, imageCountPerFile: number = IMAGE_GENERATED_COUNT_PER_CLASS) {

if (!fs.existsSync(path.join(folderPath, "real"))) {
throw new Error("real folder does not exist");
}

let realFiles = fs.readdirSync(path.join(folderPath, "real")).map((file) => path.join(folderPath, "real", file));
let confusionFiles: string[] = [];

if (fs.existsSync(path.join(folderPath, "confusion"))) {
confusionFiles = fs.readdirSync(path.join(folderPath, "confusion")).map((file) => path.join(folderPath, "confusion", file));
}

// getting files in folder
let tasks: Promise<void>[] = [];
let annotions: Annotion[] = [];

for (let filePath of realFiles) {

let className = getFileNameFromPathWithoutPrefix(filePath);

for (let i = 0; i < imageCountPerFile; i++) {

let annotion: Annotion = {
"image": `${className}-${i}.jpg`,
"annotions": []
};

async function __task(i: number) {

let randomCount = Math.random() * MAX_NUMBER_OF_CLASSES_IN_SINGLE_IMAGE;
randomCount = randomCount > realFiles.length + confusionFiles.length ? realFiles.length + confusionFiles.length : randomCount;
let selectedFiles = await randomSelectFromArray(realFiles.concat(confusionFiles), randomCount);
if (selectedFiles.includes(filePath)) {
// move filePath to the first
selectedFiles.splice(selectedFiles.indexOf(filePath), 1);
selectedFiles.unshift(filePath);
} else {
selectedFiles.unshift(filePath);
}

console.log(`processing ${filePath} ${i}, selected ${selectedFiles.length} files`);

let images = await Promise.all(selectedFiles.map(async (filePath) => {
let file = await promises.readFile(filePath);
let image = new Image();
image.src = file;
return image;
}));

console.log(`processing: ${filePath}, loaded images, start to place images in canvas`);

let imageSizes = images.map(getSizeOfImage).map( x => {
let averageX = CanvasSize[0] / (images.length + 1);
let averageY = CanvasSize[1] / (images.length + 1);
return [x[0] > averageX ? averageX : x[0], x[1] > averageY ? averageY : x[1]];
});

let placedPoints = randomPlaceImagesInCanvas(CANVAS_WIDTH, CANVAS_HEIGHT, imageSizes, false);

console.log(`processing: ${filePath}, placed images in canvas, start to draw images`);

let angle = 0;
let color = generateColor(i);

let [canvasWidth, canvasHeight] = CanvasSize;
const canvas = createCanvas(canvasWidth, canvasHeight);
const ctx = canvas.getContext("2d");

ctx.fillStyle = color;
ctx.fillRect(0, 0, canvasWidth, canvasHeight);

for (let j = 0; j < images.length; j++) {
const ctx = canvas.getContext("2d");

let ratio = Math.random() * 1.5 + 0.5;

let image = images[j];

let [_imageWidth, _imageHeight] = imageSizes[j];
let [imageWidth, imageHeight] = getSizeBasedOnRatio(_imageWidth, _imageHeight, ratio);

let placed = placedPoints[j][2] === 1 ? true : false;
if (!placed) {
continue;
}

let targetX = placedPoints[j][0] > imageWidth / 2 ? placedPoints[j][0] : imageWidth / 2;
let targetY = placedPoints[j][1] > imageHeight / 2 ? placedPoints[j][1] : imageHeight / 2;

let sizeAfterRotatation = rotateRectangleAndGetSize(imageWidth, imageHeight, angle);

console.log("final: ", [canvasWidth, canvasHeight], [imageWidth, imageHeight], [targetX, targetY], angle, ratio, color);

ctx.translate(targetX, targetY);
ctx.rotate(angle * Math.PI / 180);

ctx.drawImage(image, -imageWidth / 2, -imageHeight / 2, imageWidth, imageHeight);

ctx.rotate(-angle * Math.PI / 180);
ctx.translate(-targetX, -targetY);

// ctx.fillStyle = "green";
// ctx.strokeRect(targetX - sizeAfterRotatation[0] / 2, targetY - sizeAfterRotatation[1] / 2, sizeAfterRotatation[0], sizeAfterRotatation[1]);

annotion.annotions.push({
"label": getFileNameFromPathWithoutPrefix(selectedFiles[j]),
"coordinates": {
"x": targetX,
"y": targetY,
"width": sizeAfterRotatation[0],
"height": sizeAfterRotatation[1]
}
});
}

if (!annotion.annotions.length) {
return;
}

let fileName = path.join(outputDir, `${className}-${i}.jpg`);
let pngData = await canvas.encode("jpeg");
await promises.writeFile(fileName, pngData);

annotions.push(annotion);
}

tasks.push(__task(i));

}

}

await concurrentlyExecutePromisesWithSize(tasks, CONCURRENT_PROMISE_SIZE);

await promises.writeFile(path.join(outputDir, "annotions.json"), JSON.stringify(annotions, null, 4));

}

async function generateYoloFormatOutput(folderPath: string) {
const annotions = JSON.parse((await promises.readFile(path.join(folderPath, "annotions.json"))).toString("utf-8")) as Annotion[];

// generate data.yml
let classes: string[] = [];
for (let annotion of annotions) {
for (let label of annotion.annotions.map(a => a.label)) {
if (!classes.includes(label)) {
classes.push(label);
}
}
}

let dataYml = `
train: ./train/images
val: ./valid/images
test: ./test/images

nc: ${classes.length}
names: ${JSON.stringify(classes)}
`
await promises.writeFile(path.join(folderPath, "data.yml"), dataYml);

const weights = [0.85, 0.90, 0.95];
const split = ["train", "valid", "test"];

let tasks: Promise<void>[] = [];

async function __task(annotion: Annotion) {
const randomSeed = Math.random();
let index = 0;
for (let i = 0; i < weights.length; i++) {
if (randomSeed < weights[i]) {
index = i;
break;
}
}
let splitFolderName = split[index];
await makeSureFolderExists(path.join(folderPath, splitFolderName));
await makeSureFolderExists(path.join(folderPath, splitFolderName, "images"));
await makeSureFolderExists(path.join(folderPath, splitFolderName, "labels"));

// get info of image
let image = await Sharp(path.join(folderPath, annotion.image)).metadata();

// generate label files
let line: [number, number, number, number, number][] = []
for (let i of annotion.annotions) {
line.push([
classes.indexOf(i.label),
i.coordinates.x / image.width!,
i.coordinates.y / image.height!,
i.coordinates.width / image.width!,
i.coordinates.height / image.height!
])
}

await promises.rename(path.join(folderPath, annotion.image), path.join(folderPath, splitFolderName, "images", annotion.image));
await promises.writeFile(path.join(folderPath, splitFolderName, "labels", annotion.image.replace(".jpg", ".txt")), line.map(l => l.join(" ")).join("\n"));
}

for (let annotion of annotions) {
tasks.push(__task(annotion));
}

await concurrentlyExecutePromisesWithSize(tasks, CONCURRENT_PROMISE_SIZE);

}

(async () => {

await generateCreateMLFormatOutput("./database", "./output");

// await generateYoloFormatOutput("./output");

})();

这个脚本的思路大概是将这 40 多张图片随意揉成各种可能的形状,然后选取若干张把它撒在画布上,画布的背景也是随机的,用来模拟足够多的场景。

顺带一说,上面 500MB 这个量级并不是一下子就定好的,而是不断试验,为了更高的准确度一步一步地提高量级。

模型训练

下一步就比较简单了,在 CreateML 上选取你的数据集,然后就可以训练了:

图片

图片

可以看到 CreateML 的 Object Detection 其实是基于 Yolo V2 的,最先进的 Yolo 版本应该是 Yolo V7,但是生态最健全的应该还是 Yolo V5。

图片

在我的 M1 Pro 机器上大概需要训练 10h+,在 Intel 的笔记本上训练时间会更长。整个过程有点像「炼蛊」了,从 500 多 MB 的文件算出一个 80MB 的文件。

模型测试

训练完之后,你可以得到上面「尝鲜」中得到模型文件,大概它拖动任意文件进去,就可以测试模型的效果了:

图片

在 iOS 中使用的模型

官方的 Demo 可以参照这个例子:

​https://developer.apple.com/documentation/vision/recognizing_objects_in_live_capture​

个人用 SwiftUI 写了一个 Demo:

//
// ContentView.swift
// DemoProject
/
//

import SwiftUI
import Vision

class MyVNModel: ObservableObject {

static let shared: MyVNModel = MyVNModel()

@Published var parsedModel: VNCoreMLModel? = .none
var images: [UIImage]? = .none
var observationList: [[VNObservation]]? = .none

func applyModelToCgImage(image: CGImage) async throws -> [VNObservation] {
guard let parsedModel = parsedModel else {
throw EvaluationError.resourceNotFound("cannot find parsedModel")
}

let resp = try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<[VNObservation], Error>) in
let requestHandler = VNImageRequestHandler(cgImage: image)
let request = VNCoreMLRequest(model: parsedModel) { request, error in
if let _ = error {
return
}
if let results = request.results {
continuation.resume(returning: results)
} else {
continuation.resume(throwing: EvaluationError.invalidExpression(
"cannot find observations in result"
))
}
}
#if targetEnvironment(simulator)
request.usesCPUOnly = true
#endif
do {
// Perform the text-recognition request.
try requestHandler.perform([request])
} catch {
continuation.resume(throwing: error)
}
}
return resp
}

init() {
Task(priority: .background) {
let urlPath = Bundle.main.url(forResource: "genshin2", withExtension: "mlmodelc")
guard let urlPath = urlPath else {
print("cannot find file genshin2.mlmodelc")
return
}

let config = MLModelConfiguration()
let modelResp = await withCheckedContinuation { continuation in
MLModel.load(contentsOf: urlPath, configuration: config) { result in
continuation.resume(returning: result)
}
}

let model = try { () -> MLModel in
switch modelResp {
case let .success(m):
return m
case let .failure(err):
throw err
}
}()

let parsedModel = try VNCoreMLModel(for: model)
DispatchQueue.main.async {
self.parsedModel = parsedModel
}
}
}

}

struct ContentView: View {

enum SheetType: Identifiable {
case photo
case confirm
var id: SheetType { self }
}

@State var showSheet: SheetType? = .none

@ObservedObject var viewModel: MyVNModel = MyVNModel.shared

var body: some View {
VStack {
Button {
showSheet = .photo
} label: {
Text("Choose Photo")
}
}
.sheet(item: $showSheet) { sheetType in
switch sheetType {
case .photo:
PhotoLibrary(handlePickedImage: { images in

guard let images = images else {
print("no images is selected")
return
}

var observationList: [[VNObservation]] = []
Task {
for image in images {

guard let cgImage = image.cgImage else {
throw EvaluationError.cgImageRetrievalFailure
}

let result = try await viewModel.applyModelToCgImage(image: cgImage)
print("model applied: (result)")

observationList.append(result)
}

DispatchQueue.main.async {
viewModel.images = images
viewModel.observationList = observationList
self.showSheet = .confirm
}
}

}, selectionLimit: 1)
case .confirm:
if let images = viewModel.images, let observationList = viewModel.observationList {
VNObservationConfirmer(imageList: images, observations: observationList, onSubmit: { _,_ in

})
} else {
Text("No Images (viewModel.images?.count ?? 0) (viewModel.observationList?.count ?? 0)")
}

}

}
.padding()
}
}

struct ContentView_Previews: PreviewProvider {
static var previews: some View {
ContentView()
}
}

运行效果


图片

责任编辑:武晓燕 来源: ELab团队
相关推荐

2013-04-15 09:48:40

AndroidAVD错误处理方法

2010-04-23 09:51:12

Oracle工具

2011-08-19 17:44:01

2017-05-25 11:49:30

Android网络请求OkHttp

2010-09-15 17:29:20

无线局域网

2014-04-23 13:30:23

类簇iOS开发

2011-09-02 19:12:59

IOS应用Sqlite数据库

2021-12-09 09:52:36

云原生安全工具云安全

2018-07-30 08:20:39

编程语言Python集合

2011-06-15 15:16:54

Session

2017-03-16 20:00:17

Kafka设计原理达观产品

2013-07-19 12:52:50

iOS中BlockiOS开发学习

2010-04-30 11:10:32

Oracle Sql

2010-07-19 16:55:51

Telnet命令

2010-03-10 11:45:15

云计算

2012-02-13 14:10:11

MonoTouchiOS应用Visual Stud

2012-02-13 14:22:22

MonoTouchiOS应用Visual Stud

2021-01-07 09:35:49

Pythontqdm进度

2021-10-18 12:01:17

iOS自动化测试Trip

2011-09-02 19:24:20

SqliteIOS应用数据库
点赞
收藏

51CTO技术栈公众号