import functools
import sys
import threading
def show_guts(f):
    """Decorate *f* so that every call prints its local variables and result.

    While the wrapped callable runs, a ``sys.settrace`` hook follows the
    first Python frame entered by the call (normally the body of ``f``;
    frames for functions called from it are not followed) and snapshots
    ``frame.f_locals`` on each ``return`` or ``exception`` event.  When the
    call finishes, the most recent snapshot is printed as one
    ``name: repr(value)`` line per local, followed by ``Returned: ...`` if
    ``f`` returned normally.  An exception raised by ``f`` still propagates
    to the caller after the locals have been printed.

    Parameters:
        f: the callable to instrument.

    Returns:
        A wrapper that takes the same arguments as ``f`` and returns
        whatever ``f`` returns.
    """
    # Distinguishes "f raised" from "f returned None" in the report below.
    sentinel = object()

    @functools.wraps(f)
    def wrapper(*args, **kw):
        # The capture state belongs to this single call, so calls made from
        # other threads or nested decorated calls never share or clobber it.
        tracing = False
        captured_locals = None

        def trace_locals(frame, event, arg):
            nonlocal tracing, captured_locals
            if event.startswith('c_'):  # C code traces, no new hook
                return None
            if event == 'call':  # start tracing only the first call
                if tracing:
                    return None
                tracing = True
                return trace_locals
            if event == 'line':  # continue tracing
                return trace_locals
            # event is either exception or return: capture locals
            captured_locals = frame.f_locals.copy()
            return None

        # preserve existing tracer, start our trace
        old_trace = sys.gettrace()
        sys.settrace(trace_locals)

        retval = sentinel
        try:
            retval = f(*args, **kw)
        finally:
            # reinstate existing tracer, then report
            sys.settrace(old_trace)
            # Nothing is captured when no Python frame was traced (for
            # example when f is implemented in C); report no locals then
            # instead of failing here and hiding f's result or exception.
            for key, val in (captured_locals or {}).items():
                print('{}: {!r}'.format(key, val))
            if retval is not sentinel:
                print('Returned: {!r}'.format(retval))

        return retval

    return wrapper
show_guts decorator
Adapted from https://stackoverflow.com/questions/24165374/printing-a-functions-local-variable-names-and-values
Updated for Python 3.
Usage example
import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.utils import make_grid
from torchvision.datasets import CelebA
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
@show_guts
def get_score(current_classifications, original_classifications, target_indices, other_indices, penalty_weight):
    '''
    Function to return the score of the current classifications, penalizing changes
    to other classes with an L2 norm.
    Parameters:
        current_classifications: the classifications associated with the current noise
        original_classifications: the classifications associated with the original noise
        target_indices: the index of the target class
        other_indices: the indices of the other classes
        penalty_weight: the amount that the penalty should be weighted in the overall score
    Returns:
        a scalar tensor: the mean of the target-class classifications plus the
        (negative) weighted penalty for changes to the other classes
    '''
    # Steps: 1) Calculate the change between the original and current classifications (as a tensor)
    #           by indexing into the other_indices you're trying to preserve, like in x[:, features].
    #        2) Calculate the norm (magnitude) of changes per example.
    #        3) Multiply the mean of the example norms by the penalty weight.
    #           This will be your other_class_penalty.
    #           Make sure to negate the value since it's a penalty!
    #        4) Take the mean of the current classifications for the target feature over all the examples.
    #           This mean will be your target_score.
    #### START CODE HERE ####
    # Change in the classes that should be preserved (columns other_indices).
    change_original_classification = (
        current_classifications[:, other_indices] - original_classifications[:, other_indices]
    )
    # Calculate the norm (magnitude) of changes per example and multiply by penalty weight
    other_class_penalty = -torch.mean(torch.norm(change_original_classification, dim=1) * penalty_weight)
    # Take the mean of the current classifications for the target feature only;
    # averaging over every column would let the other classes leak into the score.
    target_score = torch.mean(current_classifications[:, target_indices])
    #### END CODE HERE ####
    return target_score + other_class_penalty
# Build two identical (rows x 4) tensors whose columns hold 1, 2, 3 and 4.
rows = 10
current_class = torch.tensor([[1] * rows, [2] * rows, [3] * rows, [4] * rows]).T.float()
original_class = torch.tensor([[1] * rows, [2] * rows, [3] * rows, [4] * rows]).T.float()

# Must be 3: the mean of columns 1 and 3 (values 2 and 4), with zero penalty
# because the current and original classifications are identical.
assert get_score(current_class, original_class, [1, 3], [0, 2], 0.2).item() == 3
current_classifications: tensor([[1., 2., 3., 4.],
[1., 2., 3., 4.],
[1., 2., 3., 4.],
[1., 2., 3., 4.],
[1., 2., 3., 4.],
[1., 2., 3., 4.],
[1., 2., 3., 4.],
[1., 2., 3., 4.],
[1., 2., 3., 4.],
[1., 2., 3., 4.]])
original_classifications: tensor([[1., 2., 3., 4.],
[1., 2., 3., 4.],
[1., 2., 3., 4.],
[1., 2., 3., 4.],
[1., 2., 3., 4.],
[1., 2., 3., 4.],
[1., 2., 3., 4.],
[1., 2., 3., 4.],
[1., 2., 3., 4.],
[1., 2., 3., 4.]])
target_indices: [1, 3]
other_indices: [0, 2]
penalty_weight: 0.2
change_original_classification: tensor([[0., 0.],
[0., 0.],
[0., 0.],
[0., 0.],
[0., 0.],
[0., 0.],
[0., 0.],
[0., 0.],
[0., 0.],
[0., 0.]])
other_class_penalty: tensor(-0.)
target_score: tensor(2.5000)
Returned: tensor(2.5000)
AssertionError: