SimulAI built-in optimizers

Built-in Optimizers#

SpaRSA#

Source code in simulai/optimization/_builtin.py
# Imports used by this excerpt (defined at the top of the module)
import sys

import numpy as np


class SpaRSA:
    def __init__(
        self,
        lambd: float = None,
        alpha_0: float = None,
        epsilon: float = 1e-10,
        sparsity_tol: float = 1e-15,
        use_mean: bool = False,
        transform: callable = None,
    ) -> None:
        """Sparse Regression Algorithm

        Args:
            lambd (float): Quadratic regularization penalty.
            alpha_0 (float): Update step length.
            epsilon (float): Error tolerance.
            sparsity_tol (float): Sparsity tolerance. 
            use_mean (bool): Use mean for evaluating loss or not.
            transform (callable): A transformation to be applied to the data.  

        """

        self.lambd = lambd
        self.alpha_0 = alpha_0
        self.epsilon = epsilon
        self.sparsity_tol = sparsity_tol
        self.size = 1

        if transform is not None:
            self.transform = transform
        else:
            self.transform = self._bypass

        if use_mean is True:
            self.norm = lambda x: x / self.size
        else:
            self.norm = lambda x: x

        self.m = 0
        self.r = 0
        self.ref_step = 5
        self.lr_reduction = 1 / 2
        self.lr_increase = 3 / 2

        self.W = None
        self.target_data = None

    def _bypass(self, data: np.ndarray) -> np.ndarray:
        return data

    def _F_lambda(self, V_bar: np.ndarray = None) -> np.ndarray:
        residual = (
            np.linalg.norm(self._WV_bar(W=self.W, V_bar=V_bar) - self.target_data, None)
            ** 2
        )

        # The penalty is the sum of the Euclidean norms of the rows of V_bar;
        # it is weighted by lambd exactly once, in the return below
        regularization = np.sum(np.linalg.norm(V_bar, 2, axis=1))

        return (1 / 2) * residual + self.lambd * regularization

    def R_alpha(
        self,
        W: np.ndarray = None,
        V_bar: np.ndarray = None,
        target_data: np.ndarray = None,
        alpha: float = 0,
    ) -> np.ndarray:
        return V_bar - alpha * W.T @ (self._WV_bar(W=W, V_bar=V_bar) - target_data)

    def _WV_bar(self, W: np.ndarray = None, V_bar: np.ndarray = None) -> np.ndarray:
        return W @ V_bar

    def _no_null_V_plus(
        self, R_alpha: np.ndarray = None, alpha: float = 0
    ) -> np.ndarray:
        return (1 - self.lambd * alpha / np.linalg.norm(R_alpha, None)) * R_alpha

    def V_plus(self, R_alpha: np.ndarray = None, alpha: float = None):
        # Zeroing rows according to the regularization criterion
        def _row_function(vector: np.ndarray = None) -> np.ndarray:
            norm = np.linalg.norm(vector, None)

            if norm <= self.lambd * alpha:
                return np.zeros(vector.shape)
            else:
                return self._no_null_V_plus(R_alpha=vector, alpha=alpha)

        rows = np.apply_along_axis(_row_function, 1, R_alpha)

        return rows

    def fit(
        self, input_data: np.ndarray = None, target_data: np.ndarray = None
    ) -> np.ndarray:
        """Fit the sparse coefficient matrix to the data.

        Args:
            input_data (np.ndarray): Input data for training the model.
            target_data (np.ndarray): Target data for training the model.

        Returns:
            np.ndarray: The fitted sparse coefficient matrix.

        """
        self.W = self.transform(data=input_data)
        self.target_data = target_data

        self.q = self.W.shape[-1]
        self.m = target_data.shape[-1]
        self.size = self.target_data.size

        V_0 = np.random.rand(self.q, self.m)

        V_k = V_0
        F_lambda_list = list()

        alpha = self.alpha_0
        stopping_criterion = False
        k = 0

        while not stopping_criterion:
            V_bar = V_k
            R_alpha = self.R_alpha(
                W=self.W, V_bar=V_bar, target_data=target_data, alpha=alpha
            )
            V_plus = self.V_plus(R_alpha=R_alpha, alpha=alpha)

            F_lambda_V_plus = self._F_lambda(V_bar=V_plus)
            F_lambda_V_bar = self._F_lambda(V_bar=V_bar)

            while F_lambda_V_plus >= F_lambda_V_bar:
                residual = F_lambda_V_plus - F_lambda_V_bar

                sys.stdout.write(
                    ("\ralpha: {}, discrepancy: {}").format(alpha, residual)
                )
                sys.stdout.flush()

                alpha = alpha * self.lr_reduction
                R_alpha = self.R_alpha(
                    W=self.W, V_bar=V_bar, target_data=target_data, alpha=alpha
                )
                V_plus = self.V_plus(R_alpha=R_alpha, alpha=alpha)

                F_lambda_V_plus = self._F_lambda(V_bar=V_plus)

            F_lambda = F_lambda_V_bar
            F_lambda_list.append(F_lambda)

            V_k = V_plus
            alpha = min(self.lr_increase * alpha, self.alpha_0)

            if k > self.ref_step:
                F_lambda_ref = F_lambda_list[-self.ref_step - 1]

                if np.abs(F_lambda - F_lambda_ref) / F_lambda_ref <= self.epsilon:
                    stopping_criterion = True

            sys.stdout.write(("\rresidual loss: {}").format(self.norm(F_lambda)))
            sys.stdout.flush()

            k += 1

        V_k = np.where(np.abs(V_k) < self.sparsity_tol, 0, V_k)

        return V_k
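
In the notation of the source above, fit alternates a gradient step on the data-fit term (R_alpha) with a row-wise proximal step on the sparsity penalty (V_plus). Writing D for the target data and v_i, r_i for the i-th rows of V and R_alpha, the objective and update read:

$$
F_\lambda(V) = \frac{1}{2}\,\lVert W V - D \rVert_F^2 + \lambda \sum_i \lVert v_i \rVert_2
$$

$$
R_\alpha = V - \alpha\, W^\top (W V - D), \qquad
v_i^{+} =
\begin{cases}
0, & \lVert r_i \rVert_2 \le \lambda \alpha, \\
\left(1 - \dfrac{\lambda \alpha}{\lVert r_i \rVert_2}\right) r_i, & \text{otherwise.}
\end{cases}
$$

The step length alpha is halved (lr_reduction) whenever the candidate fails to decrease F_lambda, grown again after an accepted step (lr_increase, capped at alpha_0), and the iteration stops once the relative change of F_lambda over the last ref_step iterations drops below epsilon.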

__init__(lambd=None, alpha_0=None, epsilon=1e-10, sparsity_tol=1e-15, use_mean=False, transform=None) #

Sparse Regression Algorithm

Parameters:

Name          Type      Description                                  Default
lambd         float     Quadratic regularization penalty.            None
alpha_0       float     Update step length.                          None
epsilon       float     Error tolerance.                             1e-10
sparsity_tol  float     Sparsity tolerance.                          1e-15
use_mean      bool      Use mean for evaluating loss or not.         False
transform     callable  A transformation to be applied to the data.  None
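
A minimal construction sketch. It assumes SpaRSA is importable from simulai.optimization (the class lives in simulai/optimization/_builtin.py); lambd and alpha_0 default to None, so both should always be passed explicitly, and the values below are placeholders rather than recommended settings.

from simulai.optimization import SpaRSA  # assumed import path

# lambd and alpha_0 have no usable defaults; these values are illustrative only
sparsa = SpaRSA(lambd=1e-3, alpha_0=1e-2, use_mean=True)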

fit(input_data=None, target_data=None) #

Parameters:

Name         Type     Description                          Default
input_data   ndarray  Input data for training the model.   None
target_data  ndarray  Target data for training the model.  None

Returns:

np.ndarray: The fitted sparse coefficient matrix.
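
An end-to-end sketch on synthetic data, following the shapes used in the source (input_data with q columns, target_data with m columns). The import path and hyperparameter values are illustrative assumptions, not recommendations.

import numpy as np

from simulai.optimization import SpaRSA  # assumed import path

# Synthetic problem: targets generated by a row-sparse coefficient matrix
rng = np.random.default_rng(0)
input_data = rng.standard_normal((200, 10))      # shape (n, q)
V_true = np.zeros((10, 3))                       # shape (q, m), mostly zero rows
V_true[[1, 4, 7]] = rng.standard_normal((3, 3))
target_data = input_data @ V_true

sparsa = SpaRSA(lambd=1e-3, alpha_0=1e-2)
V_est = sparsa.fit(input_data=input_data, target_data=target_data)

print(V_est.shape)  # (10, 3); entries smaller than sparsity_tol are zeroed out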


BBI#

Bases: Optimizer

Source code in simulai/optimization/_builtin_pytorch.py
# Imports used by this excerpt (defined at the top of the module)
import torch
from torch.optim import Optimizer


class BBI(Optimizer):

    def __init__(
        self,
        params: dict = None,
        lr: float = 1e-3,
        eps1: float = 1e-10,
        eps2: float = 1e-40,
        v0: float = 0,
        threshold0: int = 1000,
        threshold: int = 3000,
        deltaEn: float = 0.0,
        consEn: bool = True,
        n_fixed_bounces: int = 1,
    ) -> None:
        """Optimizer based on the BBI model of inflation.

        Args:
            params (iterable): iterable of parameters to optimize or dicts defining parameter groups
            lr (float): learning rate
            eps1 (float): tolerance used when checking energy conservation
            eps2 (float): lower bound on the potential below which no step is taken
            v0 (float): expected minimum of the potential (Delta V in the paper)
            threshold0 (int): threshold for fixed bounces (T0 in the paper)
            threshold (int): threshold for progress-dependent bounces (T1 in the paper)
            deltaEn (float): extra initial energy (delta E in the paper)
            consEn (bool): if True, enforces energy conservation at every step
            n_fixed_bounces (int): number of bounces every T0 iterations (Nb in the paper)
        """

        defaults = dict(
            lr=lr,
            eps1=eps1,
            eps2=eps2,
            v0=v0,
            threshold=threshold,
            threshold0=threshold0,
            deltaEn=deltaEn,
            consEn=consEn,
            n_fixed_bounces=n_fixed_bounces,
        )
        self.energy = None
        self.min_loss = None
        self.iteration = 0
        self.deltaEn = deltaEn
        self.n_fixed_bounces = n_fixed_bounces
        self.consEn = consEn

        super(BBI, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(BBI, self).__setstate__(state)

    def step(self, closure: callable) -> torch.Tensor:
        """
        Args:
            closure (callable): A closure that re-evaluates the model
                and returns the loss.
        Returns:
            torch.Tensor: The evaluated loss.
        """

        loss = closure()  # .item()

        # initialization
        if self.iteration == 0:
            # Use a dedicated random number generator, so that bounces stay random even when the ambient seed is fixed
            self.generator = torch.Generator(
                device=self.param_groups[0]["params"][0].device
            )
            self.generator.manual_seed(self.generator.seed() + 1)

            # Initial energy
            self.initV = loss - self.param_groups[0]["v0"]
            self.init_energy = self.initV + self.deltaEn

            # Some counters
            self.counter0 = 0
            self.fixed_bounces_performed = 0
            self.counter = 0

            self.min_loss = float("inf")

        for group in self.param_groups:
            V = loss - group["v0"]
            dt = group["lr"]
            eps1 = group["eps1"]
            eps2 = group["eps2"]
            threshold0 = group["threshold0"]
            threshold = group["threshold"]

            if V > eps2:
                EoverV = self.init_energy / V
                VoverE = V / self.init_energy

                # Now I check if loss and pi^2 are consistent with the initial value of the energy

                ps2_pre = torch.tensor(
                    0.0, device=self.param_groups[0]["params"][0].device
                )

                for p in group["params"]:
                    param_state = self.state[p]
                    d_p = p.grad.data
                    # Initialize in the direction of the gradient, with magnitude related to deltaE
                    if "momentum_buffer" not in param_state:
                        buf = param_state["momentum_buffer"] = -(
                            d_p / torch.norm(d_p)
                        ) * torch.sqrt(
                            torch.tensor(
                                ((self.init_energy**2) / self.initV) - self.initV
                            )
                        )
                    else:
                        buf = param_state["momentum_buffer"]

                    # Compute the current pi^2; "pre" means the value before the iteration step
                    ps2_pre += torch.dot(buf.view(-1), buf.view(-1))

                if self.consEn:
                    # Compare this \pi^2 with what it should have been if the energy was correct
                    ps2_correct = V * ((EoverV**2) - 1.0)

                    # Compute the rescaling factor, only if real
                    if torch.abs(ps2_pre - ps2_correct) < eps1:
                        self.rescaling_pi = 1.0
                    elif ps2_correct < 0.0:
                        self.rescaling_pi = 1.0
                    else:
                        self.rescaling_pi = torch.sqrt(ps2_correct / ps2_pre)

                # Perform the optimization step
                if (self.counter != threshold) and (self.counter0 != threshold0):
                    for p in group["params"]:
                        if p.grad is None:
                            continue
                        d_p = p.grad.data
                        param_state = self.state[p]

                        if "momentum_buffer" not in param_state:
                            buf = param_state["momentum_buffer"] = torch.zeros_like(
                                p.data
                            )
                        else:
                            buf = param_state["momentum_buffer"]

                        # Here the rescaling of momenta to enforce conservation of energy
                        if self.consEn:
                            buf.mul_(self.rescaling_pi)

                        buf.add_(-0.5 * dt * (VoverE + EoverV) * d_p)
                        p.data.add_(dt * VoverE * buf)

                    # Updates counters
                    self.counter0 += 1
                    self.counter += 1
                    self.iteration += 1

                    # Checks progress
                    if V < self.min_loss:
                        self.min_loss = V
                        self.counter = 0
                # Bounces
                else:
                    # First we iterate once to compute pi^2, we randomly regenerate the directions, and we compute the new norm squared

                    ps20 = torch.tensor(
                        0.0, device=self.param_groups[0]["params"][0].device
                    )
                    ps2new = torch.tensor(
                        0.0, device=self.param_groups[0]["params"][0].device
                    )

                    for p in group["params"]:
                        param_state = self.state[p]

                        buf = param_state["momentum_buffer"]
                        ps20 += torch.dot(buf.view(-1), buf.view(-1))
                        new_buf = param_state["momentum_buffer"] = (
                            torch.rand(
                                buf.size(), device=buf.device, generator=self.generator
                            )
                            - 0.5
                        )
                        ps2new += torch.dot(new_buf.view(-1), new_buf.view(-1))

                    # Then rescale them
                    for p in group["params"]:
                        param_state = self.state[p]
                        buf = param_state["momentum_buffer"]
                        buf.mul_(torch.sqrt(ps20 / ps2new))

                    # Update counters
                    if self.counter0 == threshold0:
                        self.fixed_bounces_performed += 1
                        if self.fixed_bounces_performed < self.n_fixed_bounces:
                            self.counter0 = 0
                        else:
                            self.counter0 += 1
                    self.counter = 0
        return loss
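
Reading the update off the source above, with V = loss - v0, E the conserved initial energy, theta the parameters, and pi the per-parameter momentum buffers, each non-bounce step applies

$$
\pi \leftarrow \pi - \frac{\Delta t}{2}\left(\frac{V}{E} + \frac{E}{V}\right)\nabla_\theta V,
\qquad
\theta \leftarrow \theta + \Delta t\,\frac{V}{E}\,\pi,
$$

where Delta t is the learning rate. When consEn is set, pi is first rescaled so that pi^2 matches the energy-consistent value E^2/V - V; whenever either counter reaches its threshold, the momenta are "bounced", i.e. redrawn uniformly at random and rescaled to preserve pi^2.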

__init__(params=None, lr=0.001, eps1=1e-10, eps2=1e-40, v0=0, threshold0=1000, threshold=3000, deltaEn=0.0, consEn=True, n_fixed_bounces=1) #

Optimizer based on the BBI model of inflation.

Parameters:

Name             Type      Description                                                             Default
params           iterable  Iterable of parameters to optimize or dicts defining parameter groups  None
lr               float     Learning rate                                                           0.001
eps1             float     Tolerance used when checking energy conservation                        1e-10
eps2             float     Lower bound on the potential below which no step is taken               1e-40
v0               float     Expected minimum of the potential (Delta V in the paper)                0
threshold0       int       Threshold for fixed bounces (T0 in the paper)                           1000
threshold        int       Threshold for progress-dependent bounces (T1 in the paper)              3000
deltaEn          float     Extra initial energy (delta E in the paper)                             0.0
consEn           bool      If True, enforces energy conservation at every step                     True
n_fixed_bounces  int       Number of bounces every T0 iterations (Nb in the paper)                 1
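
A minimal construction sketch, assuming BBI is importable from simulai.optimization (the class lives in simulai/optimization/_builtin_pytorch.py); the model and hyperparameter values are placeholders.

import torch

from simulai.optimization import BBI  # assumed import path

model = torch.nn.Linear(4, 1)  # placeholder model
optimizer = BBI(model.parameters(), lr=1e-3, v0=0.0, consEn=True)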

step(closure) #

Parameters:

Name     Type      Description                                                   Default
closure  callable  A closure that re-evaluates the model and returns the loss.  required

Returns:

torch.Tensor: The evaluated loss.
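
Since step requires a closure, the usual PyTorch closure pattern applies (as with torch.optim.LBFGS). A training-loop sketch with synthetic data and placeholder hyperparameters, continuing the construction above:

import torch

from simulai.optimization import BBI  # assumed import path

x = torch.randn(64, 4)
y = torch.randn(64, 1)
model = torch.nn.Linear(4, 1)
optimizer = BBI(model.parameters(), lr=1e-3)

def closure():
    # Re-evaluate the loss and its gradients; BBI reads them from p.grad
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    return loss

for _ in range(100):
    loss = optimizer.step(closure)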
