-
Notifications
You must be signed in to change notification settings - Fork 643
Expand file tree
/
Copy pathoptimizer.py
More file actions
187 lines (152 loc) · 8.4 KB
/
optimizer.py
File metadata and controls
187 lines (152 loc) · 8.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
from __future__ import absolute_import
import numpy as np
from keras import backend as K
from .callbacks import Print
from .grad_modifiers import get
from .utils import utils
_PRINT_CALLBACK = Print()
def _identity(x):
return x
class Optimizer(object):
    """Performs gradient descent on an input to minimize a weighted sum of losses.

    Builds a backend function that evaluates the individual losses, the overall
    loss, and its gradients, then iteratively updates the input via RMSProp.
    """

    def __init__(self, input_tensor, losses, input_range=(0, 255), wrt_tensor=None, norm_grads=True):
        """Creates an optimizer that minimizes weighted loss function.

        Args:
            input_tensor: An input tensor of shape: `(samples, channels, image_dims...)` if `image_data_format=
                channels_first` or `(samples, image_dims..., channels)` if `image_data_format=channels_last`.
            losses: List of ([Loss](vis.losses.md#Loss), weight) tuples.
            input_range: Specifies the input range as a `(min, max)` tuple. This is used to rescale the
                final optimized input to the given range. (Default value=(0, 255))
            wrt_tensor: Short for, with respect to. This instructs the optimizer that the aggregate loss from `losses`
                should be minimized with respect to `wrt_tensor`.
                `wrt_tensor` can be any tensor that is part of the model graph. Default value is set to None
                which means that loss will simply be minimized with respect to `input_tensor`.
            norm_grads: True to normalize gradients. Normalization avoids very small or large gradients and ensures
                a smooth gradient descent process. If you want the actual gradient
                (for example, visualizing attention), set this to false.

        Raises:
            ValueError: If every loss has weight 0, since there is then nothing to optimize.
        """
        self.input_tensor = input_tensor
        self.input_range = input_range
        self.loss_names = []
        self.loss_functions = []
        self.wrt_tensor = self.input_tensor if wrt_tensor is None else wrt_tensor

        if self.input_tensor is self.wrt_tensor:
            self.wrt_tensor_is_input_tensor = True
            # Wrap in an identity op so `wrt_tensor` is a distinct graph node
            # whose value can be fetched alongside the gradients.
            self.wrt_tensor = K.identity(self.wrt_tensor)
        else:
            self.wrt_tensor_is_input_tensor = False

        overall_loss = None
        for loss, weight in losses:
            # Perf optimization. Don't build loss function with 0 weight.
            if weight != 0:
                loss_fn = weight * loss.build_loss()
                overall_loss = loss_fn if overall_loss is None else overall_loss + loss_fn
                self.loss_names.append(loss.name)
                self.loss_functions.append(loss_fn)

        # Fail fast with a clear message; otherwise K.gradients(None, ...) below
        # would raise an obscure error.
        if overall_loss is None:
            raise ValueError('There is nothing to optimize: all loss weights are 0')

        # Compute gradient of overall with respect to `wrt` tensor.
        if self.wrt_tensor_is_input_tensor:
            grads = K.gradients(overall_loss, self.input_tensor)[0]
        else:
            grads = K.gradients(overall_loss, self.wrt_tensor)[0]
        if norm_grads:
            grads = K.l2_normalize(grads)

        # The main function to compute various quantities in optimization loop.
        self.compute_fn = K.function([self.input_tensor, K.learning_phase()],
                                     self.loss_functions + [overall_loss, grads, self.wrt_tensor])

    def _rmsprop(self, grads, cache=None, decay_rate=0.95):
        """Uses RMSProp to compute step from gradients.

        Args:
            grads: numpy array of gradients.
            cache: numpy array of same shape as `grads` as RMSProp cache
            decay_rate: How fast to decay cache

        Returns:
            A tuple of
                step: numpy array of the same shape as `grads` giving the step.
                    Note that this does not yet take the learning rate into account.
                cache: Updated RMSProp cache.
        """
        if cache is None:
            cache = np.zeros_like(grads)
        cache = decay_rate * cache + (1 - decay_rate) * grads ** 2
        # K.epsilon() guards against division by zero for all-zero cache entries.
        step = -grads / np.sqrt(cache + K.epsilon())
        return step, cache

    def _get_seed_input(self, seed_input):
        """Creates a random `seed_input` if None. Otherwise:
            - Ensures batch_size dim on provided `seed_input`.
            - Shuffle axis according to expected `image_data_format`.
        """
        desired_shape = (1, ) + K.int_shape(self.input_tensor)[1:]
        if seed_input is None:
            # Random noise centered on the middle of the input range with a
            # spread of 5% of that range.
            return utils.random_array(desired_shape, mean=np.mean(self.input_range),
                                      std=0.05 * (self.input_range[1] - self.input_range[0]))

        # Add batch dim if needed.
        if len(seed_input.shape) != len(desired_shape):
            seed_input = np.expand_dims(seed_input, 0)

        # Only possible if channel idx is out of place.
        if seed_input.shape[-1] != desired_shape[-1] and \
                seed_input.shape[1] != desired_shape[1]:
            seed_input = np.moveaxis(seed_input, -1, 1)
        return seed_input.astype(K.floatx())

    def minimize(self, seed_input=None, max_iter=200,
                 input_modifiers=None, grad_modifier=None,
                 callbacks=None, verbose=True):
        """Performs gradient descent on the input image with respect to defined losses.

        Args:
            seed_input: An N-dim numpy array of shape: `(samples, channels, image_dims...)` if `image_data_format=
                channels_first` or `(samples, image_dims..., channels)` if `image_data_format=channels_last`.
                Seeded with random noise if set to None. (Default value = None)
            max_iter: The maximum number of gradient descent iterations. (Default value = 200)
            input_modifiers: A list of [InputModifier](vis.input_modifiers.md#inputmodifier) instances specifying
                how to make `pre` and `post` changes to the optimized input during the optimization process.
                `pre` is applied in list order while `post` is applied in reverse order. For example,
                `input_modifiers = [f, g]` means that `pre_input = g(f(inp))` and `post_input = f(g(inp))`
            grad_modifier: gradient modifier to use. See [grad_modifiers](vis.grad_modifiers.md). If you don't
                specify anything, gradients are unchanged. (Default value = None)
            callbacks: A list of [OptimizerCallback](vis.callbacks.md#optimizercallback) instances to trigger.
            verbose: Logs individual losses at the end of every gradient descent iteration.
                Very useful to estimate loss weight factor(s). (Default value = True)

        Returns:
            The tuple of `(optimized input, grads with respect to wrt, wrt_value)` after gradient descent iterations.
        """
        seed_input = self._get_seed_input(seed_input)
        input_modifiers = input_modifiers or []
        grad_modifier = _identity if grad_modifier is None else get(grad_modifier)

        # Copy the callbacks list so that appending _PRINT_CALLBACK below does
        # not mutate the list the caller passed in.
        callbacks = list(callbacks) if callbacks else []
        if verbose:
            callbacks.append(_PRINT_CALLBACK)

        cache = None
        best_loss = float('inf')
        best_input = None

        grads = None
        wrt_value = None

        for i in range(max_iter):
            # Apply modifiers `pre` step
            for modifier in input_modifiers:
                seed_input = modifier.pre(seed_input)

            # 0 learning phase for 'test'
            computed_values = self.compute_fn([seed_input, 0])
            losses = computed_values[:len(self.loss_names)]
            named_losses = list(zip(self.loss_names, losses))
            overall_loss, grads, wrt_value = computed_values[len(self.loss_names):]

            # TODO: theano grads shape is inconsistent for some reason. Patch for now and investigate later.
            if grads.shape != wrt_value.shape:
                grads = np.reshape(grads, wrt_value.shape)

            # Apply grad modifier.
            grads = grad_modifier(grads)

            # Trigger callbacks
            for c in callbacks:
                c.callback(i, named_losses, overall_loss, grads, wrt_value)

            # Gradient descent update.
            # It only makes sense to do this if wrt_tensor is input_tensor. Otherwise shapes wont match for the update.
            if self.wrt_tensor_is_input_tensor:
                step, cache = self._rmsprop(grads, cache)
                seed_input += step

            # Apply modifiers `post` step
            for modifier in reversed(input_modifiers):
                seed_input = modifier.post(seed_input)

            # Track the best input seen so far by overall loss.
            if overall_loss < best_loss:
                best_loss = overall_loss.copy()
                best_input = seed_input.copy()

        # Trigger on_end
        for c in callbacks:
            c.on_end()
        return utils.deprocess_input(best_input[0], self.input_range), grads, wrt_value