Trace a function's local variables with the show_guts decorator;
useful for debugging.
Adapted from https://stackoverflow.com/questions/24165374/printing-a-functions-local-variable-names-and-values
Updated for Python 3.
import functools
import sys
import threading
def show_guts(f):
    """Decorator that prints the wrapped function's final local variables.

    Each call runs ``f`` under a temporary ``sys.settrace`` hook that
    snapshots the locals of ``f``'s frame when it raises or returns. After
    the call, every captured local is printed as ``name: repr``. If the call
    completed normally, this is followed by ``Returned: repr``. Any
    previously installed tracer is restored, and exceptions raised by ``f``
    still propagate.

    A call made while this thread is already tracing the decorated function
    (e.g. recursion) runs untraced. Only the outermost call is reported.
    """
    sentinel = object()
    # Per-thread state. Attributes are read with defaults / set per call,
    # because a threading.local only has attributes set in the current thread.
    state = threading.local()

    def trace_locals(frame, event, arg):
        if event == 'call':
            # Trace only the first Python frame entered after settrace()
            # (the wrapped function); frames it calls are not traced.
            if state.tracing:
                return None
            state.tracing = True
            return trace_locals
        if event in ('exception', 'return'):
            # Snapshot the locals; the 'return' snapshot is the final state.
            # dict() copies both plain dicts and the 3.13+ FrameLocalsProxy.
            state.captured_locals = dict(frame.f_locals)
        # Keep tracing this frame so the final 'return' event is still seen.
        return trace_locals

    @functools.wraps(f)
    def wrapper(*args, **kw):
        if getattr(state, 'tracing', False):
            # Re-entered while this thread is already tracing a call:
            # run untraced instead of clobbering the outer call's state.
            return f(*args, **kw)
        state.captured_locals = None
        state.tracing = False
        # Preserve any existing tracer (debugger, coverage), start ours.
        old_trace = sys.gettrace()
        sys.settrace(trace_locals)
        retval = sentinel
        try:
            retval = f(*args, **kw)
        finally:
            # Reinstate the previous tracer, clean up, then report.
            sys.settrace(old_trace)
            captured = state.captured_locals
            state.captured_locals = None
            state.tracing = False
            # captured stays None if no Python frame was traced
            # (e.g. f is a builtin); print nothing rather than crash.
            for key, val in (captured or {}).items():
                print('{}: {!r}'.format(key, val))
            if retval is not sentinel:
                print('Returned: {!r}'.format(retval))
        return retval

    return wrapper
import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.utils import make_grid
from torchvision.datasets import CelebA
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
@show_guts
def get_score(current_classifications, original_classifications, target_indices, other_indices, penalty_weight):
    '''
    Function to return the score of the current classifications, penalizing changes
    to other classes with an L2 norm.
    Parameters:
        current_classifications: the classifications associated with the current noise,
            indexed as [example, class]
        original_classifications: the classifications associated with the original noise,
            indexed the same way
        target_indices: the index (or list of indices) of the target class(es)
        other_indices: the indices of the other classes whose outputs should be preserved
        penalty_weight: the amount that the penalty should be weighted in the overall score
    Returns:
        A scalar tensor containing two terms:
        - the mean of the target-class classifications, and
        - minus penalty_weight times the mean per-example L2 norm of the
          change in the other classes.
    '''
    # Steps: 1) Calculate the change between the original and current classifications (as a tensor)
    #           by indexing into the other_indices you're trying to preserve, like in x[:, features].
    #        2) Calculate the norm (magnitude) of changes per example.
    #        3) Multiply the mean of the example norms by the penalty weight.
    #           This will be your other_class_penalty.
    #           Make sure to negate the value since it's a penalty!
    #        4) Take the mean of the current classifications for the target feature over all the examples.
    #           This mean will be your target_score.
    #### START CODE HERE ####
    change_original_classification = (
        current_classifications[:, other_indices] - original_classifications[:, other_indices]
    )
    # Per-example L2 norm of the change, weighted and averaged; negated because it is a penalty.
    other_class_penalty = -torch.mean(torch.norm(change_original_classification, dim=1) * penalty_weight)
    # Mean over examples of the TARGET columns only -- averaging every class
    # (the previous behaviour) dilutes the target signal with unrelated classes.
    target_score = torch.mean(current_classifications[:, target_indices])
    #### END CODE HERE ####
    return target_score + other_class_penalty
# Sanity check for get_score. Every example row holds the class scores
# [1, 2, 3, 4], and the "original" classifications are an identical copy,
# so the penalty term is zero. The expected score is therefore the mean of
# target columns 1 and 3: (2 + 4) / 2 == 3.
rows = 10
current_class = torch.arange(1, 5).repeat(rows, 1).float()
original_class = current_class.clone()
assert get_score(current_class, original_class, [1, 3], [0, 2], 0.2).item() == 3