show_guts decorator: prints the decorated function's local variables and return value each time it is called.

import functools
import sys
import threading

def show_guts(f):
    """Decorate *f* so that each call prints its final locals and its result.

    While *f* runs, a temporary ``sys.settrace`` hook follows the first Python
    frame started by the call and snapshots ``frame.f_locals`` whenever that
    frame raises or returns.  When the call finishes, every captured local is
    printed as ``name: repr(value)``, followed by ``Returned: repr(result)``
    if *f* returned normally.  An exception raised by *f* propagates unchanged
    after the captured locals have been printed.

    If no Python frame was traced (for example when *f* is implemented in C),
    no locals are printed.  The trace function that was active before the
    call is reinstated afterwards.

    Parameters:
        f: the callable to instrument
    Returns:
        a wrapper that forwards all arguments to *f* and carries *f*'s
        metadata (``__name__``, ``__doc__``, ...)
    """
    sentinel = object()  # distinguishes "f raised" from "f returned None"
    # Per-thread state.  Attributes are read with getattr() defaults because
    # values assigned here would only exist in the thread applying the
    # decorator, not in other threads that later call the wrapper.
    gutsdata = threading.local()

    def trace_locals(frame, event, arg):
        if event.startswith('c_'):  # C code traces, no new hook
            return None
        if event == 'call':  # start tracing only the first call
            if getattr(gutsdata, 'tracing', False):
                return None
            gutsdata.tracing = True
            return trace_locals
        if event in ('exception', 'return'):
            # Snapshot on every exception/return; the last snapshot wins, so
            # an exception handled inside f is superseded by the final return.
            gutsdata.captured_locals = frame.f_locals.copy()
        # Keep tracing this frame until it finishes.
        return trace_locals

    @functools.wraps(f)
    def wrapper(*args, **kw):
        # preserve existing tracer and any enclosing show_guts state
        # (recursive / nested decorated calls), then start our trace
        old_trace = sys.gettrace()
        outer_tracing = getattr(gutsdata, 'tracing', False)
        outer_locals = getattr(gutsdata, 'captured_locals', None)
        gutsdata.tracing = False
        gutsdata.captured_locals = None
        sys.settrace(trace_locals)

        retval = sentinel
        try:
            retval = f(*args, **kw)
        finally:
            # reinstate existing tracer and state before reporting, so a
            # failing repr() cannot leave the tracing state corrupted
            sys.settrace(old_trace)
            captured = gutsdata.captured_locals
            gutsdata.tracing = outer_tracing
            gutsdata.captured_locals = outer_locals

            # captured is None when no Python frame was traced
            for key, val in (captured or {}).items():
                print('{}: {!r}'.format(key, val))
            if retval is not sentinel:
                print('Returned: {!r}'.format(retval))

        return retval

    return wrapper

Usage example

import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.utils import make_grid
from torchvision.datasets import CelebA
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
@show_guts
def get_score(current_classifications, original_classifications, target_indices, other_indices, penalty_weight):
    '''
    Function to return the score of the current classifications, penalizing changes
    to other classes with an L2 norm.
    Both classification tensors are indexed as [example, class].
    Parameters:
        current_classifications: the classifications associated with the current noise
        original_classifications: the classifications associated with the original noise
        target_indices: the index (or indices) of the target class
        other_indices: the indices of the other classes
        penalty_weight: the amount that the penalty should be weighted in the overall score
    Returns:
        a scalar tensor: the mean target classification plus the (negative)
        weighted mean L2 norm of the change in the other classes
    '''
    # Change between current and original classifications, restricted to the
    # classes that should be preserved.
    change_original_classification = (current_classifications[:, other_indices] - original_classifications[:, other_indices])
    # Norm (magnitude) of the change per example, weighted and averaged;
    # negated because it is a penalty.
    other_class_penalty = - torch.mean(torch.norm(change_original_classification, dim=1) * penalty_weight)
    # Mean classification of the target feature(s) only, over all examples.
    target_score = torch.mean(current_classifications[:, target_indices])
    return target_score + other_class_penalty
rows = 10
# Every example is the row [1., 2., 3., 4.]: build one list per class column,
# then transpose so the tensors are indexed as [example, class].
current_class = torch.tensor([[label] * rows for label in (1, 2, 3, 4)]).T.float()
original_class = torch.tensor([[label] * rows for label in (1, 2, 3, 4)]).T.float()

# Must be 3: the classifications are unchanged, and the target columns
# (indices 1 and 3) hold the values 2 and 4.
assert get_score(
    current_class,
    original_class,
    [1, 3],
    [0, 2],
    0.2,
).item() == 3
current_classifications: tensor([[1., 2., 3., 4.],
        [1., 2., 3., 4.],
        [1., 2., 3., 4.],
        [1., 2., 3., 4.],
        [1., 2., 3., 4.],
        [1., 2., 3., 4.],
        [1., 2., 3., 4.],
        [1., 2., 3., 4.],
        [1., 2., 3., 4.],
        [1., 2., 3., 4.]])
original_classifications: tensor([[1., 2., 3., 4.],
        [1., 2., 3., 4.],
        [1., 2., 3., 4.],
        [1., 2., 3., 4.],
        [1., 2., 3., 4.],
        [1., 2., 3., 4.],
        [1., 2., 3., 4.],
        [1., 2., 3., 4.],
        [1., 2., 3., 4.],
        [1., 2., 3., 4.]])
target_indices: [1, 3]
other_indices: [0, 2]
penalty_weight: 0.2
change_original_classification: tensor([[0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.]])
other_class_penalty: tensor(-0.)
target_score: tensor(2.5000)
Returned: tensor(2.5000)
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-5-c7e77f2e4ae1> in <module>
      4 
      5 # Must be 3
----> 6 assert get_score(current_class, original_class, [1, 3] , [0, 2], 0.2).item() == 3

AssertionError: