TensorFlow.js: машинное обучение на JavaScript с доставкой в браузер
Кратко рассмотрены основные особенности недавно вышедшей JavaScript-версии популярного фреймворка машинного обучения от Google – TensorFlow.js.
Предыстория
В прошлом году компания Google представила библиотеку deeplearn.js, позволяющую пользователям непосредственно в браузере строить модели машинного обучения, используемые, например, для классификации изображений.
Эта библиотека была усовершенствована и представлена под названием TensorFlow.js: код стал более читаемым, функции заработали быстрее и пополнились решениями, позволяющими строить детализированные модели, а использование WebGL-технологии все так же позволяет на лету обрабатывать графические данные.
Устройство TensorFlow.js
Для обучения модели обычно как разработчику, так и конечному потребителю моделей машинного обучения, требуется установка соответствующих библиотек. Однако в TensorFlow.js интерфейсом является браузер, поэтому при подключении скриптового файла исчезает необходимость в установке библиотек и отслеживании зависимостей, все «доставляется на дом».
Строение библиотеки можно представить в виде оболочечной структуры – ядерного API и покрывающих его высокоуровневых слоев:
- Ops API (сокращенно от operations – низкоуровневые операции). Синтаксически эта составляющая близка к классическому TensorFlow с Python-интерфейсом.
- Layers API (высокоуровневые слои). Это API аналогично библиотеке Keras, сводящей действия в обучении к минималистичным наборам наиболее распространенных общих команд. Например, ранее мы показывали пример на Keras решения задачи предсказания сахарного диабета в 15 строк кода, в числе которых и загрузка данных, создание и обучение модели, проверка полученных результатов.
На основе высокоуровневого API можно построить собственный интерфейс для взаимодействий пользователя с моделью. Это открывает новые возможности для быстрого создания веб-приложений с захватом данных с сенсоров переносимых устройств (акселерометра, гироскопа, камеры, GPS и т.д.) и обучения моделей для подбора релевантного контента, исходя из поведения пользователя.
При этом автоматически обеспечивается безопасность, так как данные создаются и хранятся на стороне клиента. Кроме того, TensorFlow.js позволяет использовать предобученные модели, что крайне удобно для трудоемких задач, например, таких как распознавание образов (в качестве примера посмотрите код и демо для задачи распознавания расположенных перед веб-камерой объектов).
Как выглядит работа с TensorFlow.js
Ключевой структурой данных в TensorFlow.js являются тензоры – обобщения матриц на случаи потенциально бо́льших размерностей. Создание и обучение модели в TensorFlow.js возможно двумя способами.
Первый вариант создания модели – в результате запуска HTML-файла, в скриптовом блоке которого между тегами <script></script> помещен написанный вами код команд API фреймворка. Например, следующий код решает задачу линейной регрессии:
<html> <head> <!-- Загрузка TensorFlow.js --> <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@0.10.0"> </script> <!-- Расзместите ваш код между тэгами script. Вы также можете использовать внешний .js файл --> <script> // Замтетьте, что не используется никаких выражений 'import', // команда 'tf' доступна, благодаря указанию источника скрипта выше. // Определим модель линейной регрессии. Создадим простейшую нейронную сеть из одного слоя. const model = tf.sequential(); model.add(tf.layers.dense({units: 1, inputShape: [1]})); // Подготовим модель для обучения: обозначим функцию стоимости и оптимизатор. model.compile({loss: 'meanSquaredError', optimizer: 'sgd'}); // Создадим некоторые данные для обучения. const xs = tf.tensor2d([1, 2, 3, 4], [4, 1]); const ys = tf.tensor2d([1, 3, 5, 7], [4, 1]); // Обучаем модель по данным. model.fit(xs, ys).then(() => { // Используем модель для предсказания координаты точки. // Откройте этот файл в браузере и зайдите в консоль разработчика. model.predict(tf.tensor2d([5], [1, 1])).print(); }); </script> </head> <body> </body> </html>
Другой подход состоит в том, чтобы использовать не внешний ресурс, а добавить TensorFlow.js в ваш проект, используя yarn или npm. В этом случае описанный выше код будет содержаться в js-файле:
import * as tf from '@tensorflow/tfjs'; const model = tf.sequential(); model.add(tf.layers.dense({units: 1, inputShape: [1]})); model.compile({loss: 'meanSquaredError', optimizer: 'sgd'}); const xs = tf.tensor2d([1, 2, 3, 4], [4, 1]); const ys = tf.tensor2d([1, 3, 5, 7], [4, 1]); model.fit(xs, ys).then(() => { model.predict(tf.tensor2d([5], [1, 1])).print(); });
Код идентичен тому, что запускался в HTML-блоке (комментарии проопущены), за исключением процедуры импорта, так как теперь мы используем локально установленную библиотеку. Для того, чтобы поиграть с этим кодом, воспользуйтесь ссылкой на ресурс с предустановленной библиотекой.
Примеры демо-приложений с обучением многослойных нейросетей и применением дополнительных библиотек с соответствующими примерами кода можно найти на сайте js.tensorflow.org. Видеопрезентация проекта с запуском упомянутых примеров расположена на YouTube.