This TensorFlow.js example uses the Universal Sentence Encoder to train two text classification models:
- An 'intent' classifier that classifies sentences into categories representing user intent for a query.
- A token tagger that classifies tokens within a weather-related query to identify location-related tokens.
Note: These instructions use yarn, but you can use npm run instead if you do not have yarn installed.
Install dependencies
yarn
There are four npm/yarn scripts listed in package.json for preparing the training data. Each writes out one or more new files.
The two scripts needed to train the intent classifier are:
yarn raw-to-csv: Converts the raw data into CSV format.
yarn csv-to-tensors: Converts the strings in the CSV created in step 1 into tensors.
The two scripts needed to train the token tagger are:
yarn raw-to-tagged-tokens: Extracts tokens from sentences in the original data and tags each token with a category.
yarn tokens-to-embeddings: Embeds the tokens from the queries using the Universal Sentence Encoder and writes out a lookup table.
You can run all four of these commands with
yarn prep-data
You only need to do this once. This process can take 2–5 minutes on the smaller data sets and up to an hour on the full data set. The output of these scripts is written to the training/data folder.
To train the intent classifier model run:
yarn train-intent
To train the token tagging model run:
yarn train-tagger
Each of these scripts takes multiple options; see training/train-intent.js and training/train-tagger.js for details.
These scripts will output model artifacts in the training/models
folder.
You can run both of these commands with
yarn train
Once the models are trained, you can use the following command to run the demo app:
yarn workshop-app
git checkout load-tensor
/**
 * Load the intent classifier model the first time it is requested and cache
 * it in the module-level `intent` variable; later calls reuse the cache.
 *
 * @param {string} url location of the layers model to load
 * @return {Promise<Object>} the cached (or freshly loaded) model
 */
async function loadIntentClassifer(url) {
  if (intent != null) {
    return intent;
  }
  intent = await tf.loadLayersModel(url);
  return intent;
}
/**
 * Load a custom trained token tagger model and cache it in `taggers`.
 *
 * A model that fails to load is logged and left uncached, so the function
 * resolves to undefined for it (and will try again on the next call).
 *
 * @param {string} name Type of model to load. Should be a key in modelUrls
 * @return {Promise<Object|undefined>} the loaded model, or undefined
 */
async function loadTagger(name) {
  if (taggers[name] != null) {
    return taggers[name];
  }
  const modelUrl = modelUrls[name];
  try {
    taggers[name] = await tf.loadLayersModel(modelUrl);
  } catch (err) {
    // Not necessarily an error: the user may not have trained every
    // available model type.
    console.log(`Could not load "${name}" model`);
  }
  return taggers[name];
}
/**
 * Start loading every tagger model listed in `modelUrls` plus the universal
 * sentence encoder, so the browser can cache them ahead of first use.
 *
 * @return {Promise<Array>} resolves to [encoder, ...taggers] in the key order
 *     of modelUrls
 */
async function loadTaggerModel() {
  const taggerLoads = Object.keys(modelUrls).map((name) => loadTagger(name));
  const encoderLoad = loadUSE();
  return Promise.all([encoderLoad, ...taggerLoads]);
}
window.addEventListener('load', function() {
  // Preload the models in the background so they are cached by the time the
  // user sends a message. The promise is not awaited, so attach a handler to
  // log failures instead of leaving an unhandled rejection.
  loadTaggerModel().catch((err) => {
    console.error('Failed to preload models', err);
  });
  setupListeners();
});
git checkout add-intents
/**
 * Run the intent classifier on `sentences` and return the scores for the
 * first sentence.
 *
 * A table of the intent labels and those scores is logged with
 * console.table for debugging. The intermediate tensors are always disposed,
 * even if prediction or reading the result fails.
 *
 * @param {string[]} sentences sentences to classify (only the first result is
 *     returned)
 * @return {Promise<number[]>} one score per intent label, in label order
 */
async function classify(sentences) {
  const [use, intent, metadata] = await Promise.all(
      [loadUSE(), loadIntentClassifer(DENSE_MODEL_URL), loadIntentMetadata()]);
  const {labels} = metadata;

  const activations = await use.embed(sentences);
  let prediction;
  try {
    prediction = intent.predict(activations);
    const predsArr = await prediction.array();
    console.table([labels, predsArr[0].slice()]);
    return predsArr[0];
  } finally {
    // Release GPU/CPU memory held by the tensors regardless of outcome.
    tf.dispose(prediction == null ? [activations] : [activations, prediction]);
  }
}
/**
 * Turn the intent classifier's scores into the bot's chat reply.
 *
 * @param {number[]} softmaxArr per-label scores, in the same order as the
 *     `labels` of the intent metadata
 * @param {string} inputText the user's message (unused in this version)
 * @return {Promise<string>} a shrug when the best score is below THRESHOLD,
 *     otherwise an emoji for the recognised intent ('?' for any other intent)
 */
async function getClassificationMessage(softmaxArr, inputText) {
  const metadata = await loadIntentMetadata();
  const labels = metadata.labels;
  const bestScore = Math.max(...softmaxArr);
  const bestLabel = labels[softmaxArr.indexOf(bestScore)];

  // Not confident enough in any intent.
  if (bestScore < THRESHOLD) {
    return '¯\\_(ツ)_/¯';
  }

  let response;
  switch (bestLabel) {
    case 'GetWeather':
      response = '⛅';
      break;
    case 'PlayMusic':
      response = '🎵🎺🎵';
      break;
    default:
      response = '?';
  }
  return response;
}
/**
 * Classify the user's message and append the bot's reply to the chat window.
 *
 * Only the lines added in this workshop step are shown; `...` marks the parts
 * of the function that are elided here. `msgId` is defined in the elided part
 * (presumably the id used for the chat message — confirm against the full
 * source).
 *
 * @param {string} inputText the user's message
 */
async function sendMessage(inputText) {
...
// Classify the text
const classification = await classify([inputText]);
// Add the response to the chat window
const response = await getClassificationMessage(classification, inputText);
appendMessage(response, 'bot', msgId);
...
}
git checkout add-tagging
case 'GetWeather':
const model = "bidirectional-lstm";
var location = await tagMessage(inputText, model);
response = '⛅ ' + location.trim();
break;
/**
 * Tokenize a sentence and tag the tokens.
 *
 * The sentence is split with `tokenizeSentence`, truncated to the model's
 * `sequenceLength`, embedded with the universal sentence encoder, padded with
 * rows of ones up to `sequenceLength` and run through the tagger model.
 * All tensors created here are disposed before returning, even on error.
 *
 * @param {string} sentence sentence to tag
 * @param {string} model name of model to use (a key in modelUrls)
 *
 * @return {Object} dictionary of tokens (`tokenized`), model outputs
 *     (`tokenScores`) and embeddings (`tokenEmbeddings`), one entry per
 *     token. When the input needed padding, one extra padding entry (named
 *     `labels[2]` in `tokenized`) is appended to each list for illustration.
 * @throws {Error} if the requested tagger model is not available.
 */
async function tagTokens(sentence, model = 'bidirectional-lstm') {
  const [use, tagger, metadata] =
      await Promise.all([loadUSE(), loadTagger(model), loadMetadata(model)]);
  if (tagger == null) {
    // loadTagger logs load failures and resolves to undefined; fail with a
    // clear message instead of a TypeError on `tagger.predict`.
    throw new Error(`Tagger model "${model}" is not available`);
  }
  const {labels, sequenceLength} = metadata;

  let tokenized = tokenizeSentence(sentence);
  if (tokenized.length > sequenceLength) {
    console.warn(
        `Input sentence has more tokens than max allowed tokens ` +
        `(${sequenceLength}). Extra tokens will be dropped.`);
  }
  tokenized = tokenized.slice(0, sequenceLength);
  const numPadding = sequenceLength - tokenized.length;

  const activations = await use.embed(tokenized);
  // Every tensor created from here on is released in `finally`.
  const tensors = [activations];
  try {
    const prediction = tf.tidy(() => {
      // Make an input tensor of [1, sequence_len, embedding_size].
      const padTensors = tf.ones([numPadding, EMBEDDING_DIM]);
      const padded = activations.concat(padTensors);
      return tagger.predict(padded.expandDims());
    });
    tensors.push(prediction);

    let tokenScores = (await prediction.array())[0];
    let tokenEmbeddings;
    if (numPadding > 0) {
      // Show a single padding 'token' and a single padding embedding row in
      // the UI to illustrate the padding inputs; drop the other padded
      // outputs so all three lists stay the same length.
      tokenized.push(labels[2]);
      tokenScores = tokenScores.slice(0, tokenized.length);
      const displayActivations =
          tf.tidy(() => activations.concat(tf.ones([1, EMBEDDING_DIM])));
      tensors.push(displayActivations);
      tokenEmbeddings = await displayActivations.array();
    } else {
      tokenEmbeddings = await activations.array();
    }

    return {
      tokenized: tokenized,
      tokenScores: tokenScores,
      tokenEmbeddings: tokenEmbeddings,
    };
  } finally {
    tf.dispose(tensors);
  }
}
/**
 * Extract the location mentioned in `inputText` with a token tagger model.
 *
 * @param {string} inputText the user's message
 * @param {string} model name of the tagger model to use
 * @return {Promise<string>} the tokens whose highest-scoring label is index 1
 *     (presumably the location label — confirm against the metadata), joined
 *     with single spaces; '' when none are found or the input is empty.
 */
async function tagMessage(inputText, model) {
  if (inputText == null || inputText.length === 0) {
    // Callers call `.trim()` on the result, so never resolve to undefined.
    return '';
  }
  // tokenEmbeddings is not used here; it is needed once displayTokenization
  // is called from this function in the display-tokenization step.
  const {
    tokenized,
    tokenScores,
    tokenEmbeddings
  } = await tagTokens(inputText, model);

  const LOCATION_LABEL_INDEX = 1;
  const location = tokenScores
      .map((scores, index) =>
          scores.indexOf(Math.max(...scores)) === LOCATION_LABEL_INDEX ?
              tokenized[index] :
              null)
      .filter((token) => token != null)
      .join(' ');
  console.log(location);
  return location;
}
git checkout add-weather-api
case 'GetWeather':
const model = "bidirectional-lstm";
var location = await tagMessage(inputText, model);
if (location.trim() != "") {
const weatherMessage = await getWeather(location);;
response = '⛅ ' + weatherMessage;
} else {
response = '⛅';
}
break;
/**
 * Search MetaWeather (through the cors-anywhere proxy) for places matching
 * `location`.
 *
 * @param {string} location place name; surrounding whitespace is ignored
 * @return {Promise<*>} the parsed JSON body of the search response
 */
async function getWeatherSearch(location) {
  // The location comes from user text, so encode it: characters such as '&',
  // '#' or '?' would otherwise break or alter the request URL.
  const query = encodeURIComponent(location.trim());
  const response = await fetch(`https://cors-anywhere.herokuapp.com/https://www.metaweather.com/api/location/search/?query=${query}`);
  return response.json();
}
/**
 * Build a one-sentence description of today's weather at `location`.
 *
 * Uses the first match from getWeatherSearch and then fetches that place's
 * forecast from MetaWeather (through the cors-anywhere proxy).
 *
 * @param {string} location place name (may have surrounding whitespace)
 * @return {Promise<string>} the weather sentence, or a fallback message when
 *     no place or no forecast is found
 */
async function getWeather(location) {
  const place = location.trim();
  const fallback = `I'm not smart enough to know weather data for ${place}`;

  const weatherSearch = await getWeatherSearch(location);
  if (!Array.isArray(weatherSearch) || weatherSearch.length === 0) {
    return fallback;
  }

  const weatherResponse = await fetch(`https://cors-anywhere.herokuapp.com/https://www.metaweather.com/api/location/${weatherSearch[0].woeid}/`);
  const weather = await weatherResponse.json();
  const forecasts = weather.consolidated_weather;
  if (!Array.isArray(forecasts) || forecasts.length === 0) {
    // Avoid a TypeError when the service returns no forecast entries.
    return fallback;
  }
  return `The ${weather.location_type} of ${weather.title} is expecting ${forecasts[0].weather_state_name} today.`;
}
git checkout display-tokenization
/**
 * Render the tokens, their embeddings and their predicted tags in a new
 * panel at the top of the `#taggings` element.
 *
 * The panel is inserted immediately; the returned promise resolves once the
 * tags (which need the model metadata) have been rendered.
 *
 * @param {string[]} tokens the tokens
 * @param {Array.number[]} tokenScores model scores for each token
 * @param {Array.number[]} tokenEmbeddings token embeddings
 * @param {string} model name of model
 */
async function displayTokenization(tokens, tokenScores, tokenEmbeddings, model) {
  const resultsDiv = document.createElement('div');
  resultsDiv.className = 'tagging';

  // Build the label with textContent rather than innerHTML so the model name
  // is never interpreted as markup.
  const modelLabel = document.createElement('p');
  modelLabel.className = `model-type ${model}`;
  modelLabel.textContent = model;
  resultsDiv.appendChild(modelLabel);

  displayTokens(tokens, resultsDiv);
  displayEmbeddingsPlot(tokenEmbeddings, resultsDiv);
  const tagsRendered = displayTags(tokenScores, resultsDiv, model);
  document.getElementById('taggings').prepend(resultsDiv);
  // Await the tags so failures surface through this function's promise
  // instead of becoming unhandled rejections.
  await tagsRendered;
}
/**
 * Render the tokens as a row of `.token` divs appended to `parentEl`.
 *
 * Tokens come from user input, so they are HTML-escaped before being put
 * into innerHTML.
 *
 * @param {string[]} tokens tokens to display
 * @param {HTMLElement} parentEl parent element
 */
function displayTokens(tokens, parentEl) {
  const escapeHtml = (text) => String(text)
      .replace(/&/g, '&amp;')
      .replace(/</g, '&lt;')
      .replace(/>/g, '&gt;')
      .replace(/"/g, '&quot;')
      .replace(/'/g, '&#39;');

  const tokensDiv = document.createElement('div');
  tokensDiv.className = 'tokens';
  tokensDiv.innerHTML = tokens
      .map((token) => `<div class="token">${escapeHtml(token)}</div>`)
      .join('\n');
  parentEl.appendChild(tokensDiv);
}
// Colour scale for the illustrative embedding plot: values in
// [-0.075, 0.075] are mapped onto d3's Spectral interpolator, and values
// outside that range are clamped to the end colours.
const embeddingCol = d3.scaleSequential(d3.interpolateSpectral)
    .domain([-0.075, 0.075])
    .clamp(true);
/**
 * Display an illustrative colour-strip representation of the embeddings.
 *
 * Each embedding becomes a `.embedding` div holding one coloured `.embVal`
 * div per value (coloured with `embeddingCol`, value shown as the tooltip).
 *
 * @param {Array.number[]} embeddings one array of values per token
 * @param {HTMLElement} parentEl parent element
 * @param {number} [maxValues=340] how many values of each embedding to draw;
 *     the cut-off is arbitrary as the plot is only meant to be illustrative
 */
function displayEmbeddingsPlot(embeddings, parentEl, maxValues = 340) {
  const embeddingDiv = document.createElement('div');
  embeddingDiv.className = 'embeddings';
  embeddingDiv.innerHTML = embeddings
      .map((embedding) => {
        const embeddingValDivs = embedding.slice(0, maxValues).map((val) =>
            `<div class="embVal" style="background-color:${embeddingCol(val)}" ` +
            `title="${val}"></div>`);
        return `<div class="embedding">${embeddingValDivs.join('\n')}</div>`;
      })
      .join('\n');
  parentEl.appendChild(embeddingDiv);
}
/**
 * Render the highest-scoring tag for each token, with its score as a
 * percentage, in a `.tags` div appended to `parentEl`.
 *
 * @param {Array.number[]} tokenScores model scores for each token, one score
 *     per label in the model metadata
 * @param {HTMLElement} parentEl parent element
 * @param {string} modelName name of the model whose metadata supplies labels
 */
async function displayTags(tokenScores, parentEl, modelName) {
  const {labels} = await loadMetadata(modelName);

  const tagsDiv = document.createElement('div');
  tagsDiv.className = 'tags';
  tagsDiv.innerHTML = tokenScores
      .map((scores) => {
        const maxIndex = scores.indexOf(Math.max(...scores));
        const label = labels[maxIndex];
        const percent = (scores[maxIndex] * 100).toPrecision(3);
        // Double underscores around label names are stripped for display.
        return `<div class="tag ${label}">` +
            ` ${label.replace(/__/g, '')}<sup>` +
            `${percent}%</sup></div>`;
      })
      .join('\n');
  parentEl.appendChild(tagsDiv);
}
/**
 * Display-tokenization step: the same tagMessage as before, with one added
 * call that renders the tokens, embeddings and tags in the UI.
 *
 * Only the added line is shown; `...` marks the parts of the function that
 * are elided here (presumably unchanged from the previous step — confirm).
 *
 * @param {string} inputText the user's message
 * @param {string} model name of the tagger model to use
 * @return {Promise<string>} the extracted location text
 */
async function tagMessage(inputText, model) {
...
// Render the tagging results; the returned promise is not awaited here.
displayTokenization(tokenized, tokenScores, tokenEmbeddings, model);
...
return location;
}
git checkout add-speech-synthesis
/**
 * Read `message` aloud with the browser's speech synthesis, using British
 * English at rate 5 and pitch 2.
 *
 * @param {string} message text to speak
 */
function speak(message) {
  const utterance = Object.assign(new SpeechSynthesisUtterance(), {
    text: message,
    rate: 5,
    pitch: 2,
    lang: "en-GB",
  });
  speechSynthesis.speak(utterance);
}
case 'GetWeather':
const model = "bidirectional-lstm";
var location = await tagMessage(inputText, model);
if (location.trim() != "") {
const weatherMessage = await getWeather(location);;
response = '⛅ ' + weatherMessage;
speak(weatherMessage);
} else {
response = '⛅';
}
git checkout add-speech-recognition
/**
 * Listen for one spoken query and send its transcript as a chat message.
 *
 * Uses the standard `SpeechRecognition` constructor when the browser has it,
 * otherwise the prefixed `webkitSpeechRecognition`; logs a warning and does
 * nothing when neither is available. Recognition stops after the first
 * result.
 */
function ask() {
  // Read the constructors from `window` explicitly. A local declaration such
  // as `var SpeechRecognition = SpeechRecognition || ...` reads its own
  // hoisted (still undefined) binding rather than the global.
  const Recognition = window.SpeechRecognition || window.webkitSpeechRecognition;
  if (Recognition == null) {
    console.warn('Speech recognition is not supported in this browser.');
    return;
  }

  const recognition = new Recognition();
  recognition.lang = "en-US";
  recognition.interimResults = false;
  recognition.start();
  recognition.addEventListener('result', (e) => {
    // Use the transcript of the top alternative of the latest result.
    const last = e.results.length - 1;
    const text = e.results[last][0].transcript;
    console.log(text);
    sendMessage(text);
    recognition.stop();
  });
}
/**
 * Register the UI event listeners.
 *
 * Only the line added in this workshop step is shown; `...` marks the
 * elided listeners.
 */
function setupListeners() {
...
// Start voice input (ask) on click. `speech` is presumably the microphone
// button element, obtained in the elided part or as a global — confirm.
speech.addEventListener('click', ask, false);
}