4.3. Music prediction (JSB)
This tutorial shows how to:
Interpret musical notes as a temporal stream of spikes, where the location in time matters (temporal coding).
Build a spiking network that predicts the next musical chord based on the chords seen so far.
Generate and play an entire song based on one-step predictions.
4.3.1. Musical notes as spikes
The presence of particular sounds at particular moments in time is what creates music, and it is what common musical notation conveys through sheet music. Alternatively, we can view the information about the presence of sounds, or musical notes, as spikes.
A typical piano keyboard has 88 keys (Wikimedia - Creative Commons):
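For example, a C-major chord spanning keys 40, 44 and 47 (C4, E4, G4; concert A is key 49) could be encoded for a single timestep as an 88-dimensional binary vector. The snippet below is an illustrative sketch, not part of the toolkit:
import numpy as np
chord_keys = [40, 44, 47] # C4, E4, G4 (1-indexed piano keys)
spikes = np.zeros(88, dtype=int)
spikes[[k - 1 for k in chord_keys]] = 1 # key k maps to input index k-1
# spikes is the binary input vector for this timestep: a 1 means 'play this note'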
Therefore, for each timestep of a musical piece, we can encode the notes being played by sending spikes to an input layer with 88 inputs. For illustration, let’s use the JSB dataset from N. Boulanger-Lewandowski et al., ICML 2012, which comprises chorales written by Johann Sebastian Bach. It can be loaded as follows:
[1]:
import neuroaikit.dataset.datasets as aid
x = aid.JSB()
The tuple x contains the data split into 229 training, 76 validation, and 77 test chorales:
[2]:
len(x), len(x[0]), len(x[1]), len(x[2])
[2]:
(3, 229, 76, 77)
Each music piece is a sequence of 88-dimensional vectors. The first training chorale is 129 time steps long:
[3]:
x[0][0].shape
[3]:
(129, 88)
We can visualize the first 20 time steps of this chorale as spikes. Note that the most commonly used notes are in the center of the keyboard, and the y axis is automatically cropped:
[4]:
import matplotlib.pyplot as plt
import numpy as np
plt.scatter(*np.where(x[0][0][0:20,:]), marker='|')
plt.title('Spike trains')
plt.xlabel('Timestep')
plt.xlim([-0.5,20.5])
#plt.ylim([0,88])
plt.ylabel('Input index')
[4]:
Text(0, 0.5, 'Input index')
4.3.2. Train a network
Based on the input spikes observed at the 88 inputs, we can build a network that outputs note predictions. In this case, the last layer should also have 88 outputs, as illustrated below:
Let’s import the required modules and build the network:
[5]:
import neuroaikit as ai
import neuroaikit.tf as aitf
import tensorflow as tf
The network includes one hidden layer of spiking neurons and a dense output layer without an activation function, which outputs the raw logits:
[19]:
config = {'decay': 0.8, 'g': aitf.activations.leaky_rel} # SNU state decay and activation function
model = tf.keras.Sequential()
model.add(tf.keras.layers.InputLayer(input_shape=[None, 88])) # 88 note inputs, variable-length sequences
model.add(aitf.layers.SNU(150, **config, return_sequences=True)) # hidden layer of 150 spiking neurons
model.add(tf.keras.layers.Dense(88)) # output logits, one per note
Then, we use the binary cross-entropy loss, which is appropriate for training our model because it outputs a series of separate per-note probabilities. The loss is configured to operate on raw logits:
[20]:
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.01),
loss=tf.keras.losses.BinaryCrossentropy(from_logits=True), metrics=[])
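For reference, each of the 88 outputs contributes a per-note binary cross-entropy term \(-\left[y \log p + (1-y) \log(1-p)\right]\), where \(p = \sigma(z)\) is the sigmoid of the raw logit \(z\) and \(y \in \{0, 1\}\) indicates whether the note is played; Keras then averages these terms over notes, timesteps, and batch entries.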
For training, we need pairs of input and target vectors. We create a dataset generator whose final step duplicates each sequence into an input/target pair shifted by one timestep:
[8]:
ds = tf.data.Dataset.from_generator(lambda: x[0], tf.int32, output_shapes=[None,None])
ds = ds.map(lambda x: (tf.expand_dims(x[0:-1,:],0), tf.expand_dims(x[1:,:],0)))
Thus, the dataset generator returns pairs of sequences with 88 features:
[9]:
example = ds.as_numpy_iterator()
example_x, example_y = example.next()
example_x.shape, example_y.shape
[9]:
((1, 128, 88), (1, 128, 88))
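To make the one-step shift concrete, here is a minimal standalone sketch on a toy sequence of 4 timesteps with 2 features (illustrative values only):
import numpy as np
seq = np.arange(8).reshape(4, 2) # toy sequence: 4 timesteps, 2 features
inp, tgt = seq[0:-1, :], seq[1:, :] # the same slicing as in the map above
# inp covers timesteps 0..2 and tgt covers timesteps 1..3, so the target at
# step t equals the input at step t+1: the model learns next-step prediction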
Let’s train the model for 40 epochs:
[21]:
import time
time_start = time.time()
model.fit(ds, epochs=40)
print('Finished. Total time: {0:.1f} [s]'.format(time.time() - time_start))
Epoch 1/40
229/229 [==============================] - 3s 10ms/step - loss: 0.1430
Epoch 2/40
229/229 [==============================] - 2s 10ms/step - loss: 0.1074
Epoch 3/40
229/229 [==============================] - 2s 10ms/step - loss: 0.1038
Epoch 4/40
229/229 [==============================] - 2s 10ms/step - loss: 0.1021
Epoch 5/40
229/229 [==============================] - 2s 11ms/step - loss: 0.1008
Epoch 6/40
229/229 [==============================] - 3s 11ms/step - loss: 0.0995
Epoch 7/40
229/229 [==============================] - 3s 11ms/step - loss: 0.0988
Epoch 8/40
229/229 [==============================] - 3s 12ms/step - loss: 0.0982
Epoch 9/40
229/229 [==============================] - 2s 11ms/step - loss: 0.0974
Epoch 10/40
229/229 [==============================] - 3s 11ms/step - loss: 0.0968
Epoch 11/40
229/229 [==============================] - 3s 13ms/step - loss: 0.0964
Epoch 12/40
229/229 [==============================] - 3s 12ms/step - loss: 0.0960
Epoch 13/40
229/229 [==============================] - 4s 17ms/step - loss: 0.0956
Epoch 14/40
229/229 [==============================] - 11s 46ms/step - loss: 0.0950
Epoch 15/40
229/229 [==============================] - 2s 10ms/step - loss: 0.0947
Epoch 16/40
229/229 [==============================] - 4s 17ms/step - loss: 0.0943
Epoch 17/40
229/229 [==============================] - 3s 15ms/step - loss: 0.0940
Epoch 18/40
229/229 [==============================] - 3s 15ms/step - loss: 0.0939
Epoch 19/40
229/229 [==============================] - 3s 15ms/step - loss: 0.0936
Epoch 20/40
229/229 [==============================] - 3s 14ms/step - loss: 0.0932
Epoch 21/40
229/229 [==============================] - 3s 14ms/step - loss: 0.0931
Epoch 22/40
229/229 [==============================] - 3s 13ms/step - loss: 0.0928
Epoch 23/40
229/229 [==============================] - 3s 11ms/step - loss: 0.0926
Epoch 24/40
229/229 [==============================] - 3s 11ms/step - loss: 0.0924
Epoch 25/40
229/229 [==============================] - 3s 12ms/step - loss: 0.0920
Epoch 26/40
229/229 [==============================] - 3s 11ms/step - loss: 0.0920
Epoch 27/40
229/229 [==============================] - 3s 12ms/step - loss: 0.0918
Epoch 28/40
229/229 [==============================] - 3s 12ms/step - loss: 0.0915
Epoch 29/40
229/229 [==============================] - 3s 13ms/step - loss: 0.0914
Epoch 30/40
229/229 [==============================] - 3s 14ms/step - loss: 0.0911
Epoch 31/40
229/229 [==============================] - 4s 16ms/step - loss: 0.0909
Epoch 32/40
229/229 [==============================] - 5s 21ms/step - loss: 0.0909
Epoch 33/40
229/229 [==============================] - 7s 29ms/step - loss: 0.0907
Epoch 34/40
229/229 [==============================] - 4s 19ms/step - loss: 0.0906
Epoch 35/40
229/229 [==============================] - 4s 19ms/step - loss: 0.0905
Epoch 36/40
229/229 [==============================] - 4s 17ms/step - loss: 0.0904
Epoch 37/40
229/229 [==============================] - 4s 18ms/step - loss: 0.0904
Epoch 38/40
229/229 [==============================] - 4s 16ms/step - loss: 0.0902
Epoch 39/40
229/229 [==============================] - 4s 16ms/step - loss: 0.0902
Epoch 40/40
229/229 [==============================] - 3s 15ms/step - loss: 0.0902
Finished. Total time: 156.5 [s]
We see that the loss keeps decreasing, which means the model is still improving.
The accuracy of note predictions for neural networks is commonly assessed using the average negative log-likelihood loss (the lower, the better). SNU-based networks predict the notes quite well, as illustrated in the figure below, taken from S. Woźniak et al., 2020, where more detailed explanations are provided. Please note that the loss reported in the figure uses an averaging approach different from the one in the example above, so the loss values are not directly comparable.
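As a rough sanity check that uses this tutorial’s averaging (not the paper’s evaluation protocol), the same loss can be computed on the validation split x[1] by rebuilding the dataset generator; a sketch:
ds_valid = tf.data.Dataset.from_generator(lambda: x[1], tf.int32, output_shapes=[None, None])
ds_valid = ds_valid.map(lambda s: (tf.expand_dims(s[0:-1, :], 0), tf.expand_dims(s[1:, :], 0)))
valid_loss = model.evaluate(ds_valid) # metrics=[] above, so a single scalar loss is returned
print('Validation loss: {0:.4f}'.format(valid_loss))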
4.3.3. Play one-step predictions
It’s interesting to hear what the network predicts. We can take the output logits and threshold them at zero to obtain the predicted notes; since \(\sigma(0) = 0.5\), this corresponds to predicting a note whenever its probability exceeds 0.5. Then, we can play these notes.
To play the notes, we will make use of the fact that each note is characterized by its frequency:
\(f(n) = 2^{\frac{n-49}{12}} \times 440\) [Hz] - see https://en.wikipedia.org/wiki/Piano_key_frequencies
so that we can write a simple sine-wave-based synthesizer to play the music:
[22]:
from IPython.display import Audio
def play(song, note=0.25, rate=44100):
    t = np.linspace(0, note, int(rate * note))
    data = []
    for s in range(song.shape[0]):
        val = np.zeros_like(t)
        for n in np.where(song[s,:] == 1)[0]:
            f = 2**((n-49)/12) * 440
            #val += np.sin(2 * np.pi * f * t) # "sine" sound
            val += np.clip(2*np.sin(2 * np.pi * f * t), 0, 1) # "organ" sound
        val *= np.clip(30.0*np.sin(np.pi * t/note), 0.0, 1.0) # smooth the note envelope to avoid 'cracks'
        data.append(val)
    return Audio(np.hstack(data), rate=rate, autoplay=True)
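As a quick check of the pitch formula used above, key 49 should map to concert A (440 Hz) and key 40 to middle C (about 261.6 Hz):
for n in [40, 49]:
    print(n, 2**((n - 49)/12) * 440) # prints 261.625... for key 40 and 440.0 for key 49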
Let’s compute the model’s next-step predictions for the training chorale at index 10 and play them:
[28]:
logits = model.predict(np.expand_dims(x[0][10],0))
notes = (logits > 0)*1
play(notes[0,0:30,:])
[28]:
Here is the original:
[27]:
play(x[0][10][0:30,:])
[27]:
For comparison, predictions from an untrained model:
[30]:
config = {'decay': 0.8, 'g': aitf.activations.leaky_rel}
untrained_model = tf.keras.Sequential()
untrained_model.add(tf.keras.layers.InputLayer(input_shape=[None, 88]))
untrained_model.add(aitf.layers.SNU(150, **config, return_sequences=True))
untrained_model.add(tf.keras.layers.Dense(88)) #output logits
logits = untrained_model.predict(np.expand_dims(x[0][10],0))
notes = (logits > 0)*1
play(notes[0,0:30,:])
[30]:
The original sounds best, which shows that music prediction is challenging in general. Compared to the untrained model, however, the trained one captures some musical concepts.