Direct Ascent Synthesis: Revealing Hidden Generative Capabilities in Discriminative Models
I have a paper out! What!?!? Johno doesn’t write papers. True. But when Stanislav Fort discovered a neat trick that was one I’d also found back in the day, we got talking and figured it ought to be better documented so other people can use it too. I have to say: he did all the hard work! I sadly didn’t have time to play much, but did chip in a little. This blog post is a few of my own thoughts, but you should read the paper first.
The TL;DR is as follows: instead of optimizing raw pixels, optimize a collection of image tensors at different resolutions that get resized and stacked together to form the final image. This turns out to have really neat regularization effects, and gives a really nice primitive for seeing what ‘natural’-ish images trigger various features in classifiers etc. This is pretty much the idea behind my 2021 ‘imstack’ stuff, but made cleaner and more general.
The other trick is to apply some augmentations, critically adding some jitter (different crops) and noise. Once you have these pieces in place, you can optimize towards a text prompt with CLIP, or do style transfer, or trigger a specific class in a classification model… the possibilities are endless. Here’s the code to make the quintessential ‘jellyfish’ from an ImageNet model, for example (colab):
import torch
import torch.nn.functional as F
from tqdm import tqdm

def stack(x, large_resolution):
    # Resize every layer up to the final resolution and sum them into one image
    out = 0.0
    for i, p in enumerate(x):
        out += F.interpolate(p, size=(large_resolution, large_resolution),
                             mode='bicubic' if resolutions[i] > 1 else 'nearest')
    return out

def raw_to_real_image(raw_image): return (torch.tanh(raw_image)+1.0)/2.0

large_resolution = 336
resolutions = range(1, large_resolution+1, 4)
image_layers = [torch.zeros(1, 3, res, res).to("cuda") for res in resolutions]
for p in image_layers: p.requires_grad = True
optimizer = torch.optim.SGD(image_layers, 0.005)

# `model`, `normalize` and `make_image_augmentations` are defined in the colab
for step in tqdm(range(100)):
    optimizer.zero_grad()
    images = raw_to_real_image(stack(image_layers, large_resolution))
    images = make_image_augmentations(images, count=16, jitter_scale=56, noise_scale=0.2)
    loss = -model(normalize(images))[:, 107].mean()  # 107 = ImageNet 'jellyfish' class
    loss.backward()
    optimizer.step()
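The `make_image_augmentations` helper above lives in the colab; as a rough guide to what it does (random crops for jitter, plus pixel noise), here’s a hypothetical re-implementation — the name and signature match the call above, but the exact behavior here is my assumption, not the paper’s code:

```python
import torch

def make_image_augmentations(images, count=16, jitter_scale=56, noise_scale=0.2):
    # HYPOTHETICAL sketch of the colab helper: take `count` randomly
    # offset crops of the (1, 3, H, W) input, then add Gaussian noise.
    _, _, h, w = images.shape
    crops = []
    for _ in range(count):
        # Random offset of up to `jitter_scale` pixels in each direction
        dx = torch.randint(0, jitter_scale + 1, (1,)).item()
        dy = torch.randint(0, jitter_scale + 1, (1,)).item()
        crops.append(images[:, :, dy:dy + h - jitter_scale, dx:dx + w - jitter_scale])
    out = torch.cat(crops, dim=0)
    # Noise stops the optimizer exploiting precise per-pixel values
    return out + noise_scale * torch.randn_like(out)
```

With the defaults and a 336px input this yields a batch of sixteen 280×280 crops, which is why the classifier sees many slightly different views of the same optimized image each step.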
I want to do a video explanation soon to capture more thoughts on this and show off more of what this technique can do. See also Stanislav’s announcement post. The rest of this post is me rambling on some tangential bits that have come up since the paper was released.
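One of the swaps mentioned above — optimizing towards a text prompt with CLIP — only changes the loss term in the training loop. Here’s a minimal sketch of that similarity loss, assuming you have image and text embeddings from any CLIP implementation (the encoder calls themselves, and the function name `clip_prompt_loss`, are my own for illustration):

```python
import torch
import torch.nn.functional as F

def clip_prompt_loss(image_features, text_features):
    # Negative mean cosine similarity between the augmented-image
    # embeddings and the (single) text-prompt embedding; minimizing
    # this pushes the synthesized image towards the prompt.
    image_features = F.normalize(image_features, dim=-1)
    text_features = F.normalize(text_features, dim=-1)
    return -(image_features @ text_features.T).mean()
```

In the loop above you would replace the `-model(normalize(images))[:, 107].mean()` line with this loss, computed on `clip.encode_image(images)` and a pre-computed prompt embedding.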
Thoughts
WIP, TODO:
- Describe the early days
- Link my initial experiments
- The benefits of curiosity-driven independent researchers
- ‘This doesn’t cite X’ - peer review pile-ons and the downsides of twitter
- I’m going to stick to blog posts