embeddings

train models

create corpus

# fpaths = get_fpaths_year('2020')
fpaths = get_fpaths_subreddit('conspiracy')
comments = read_multi_comments_csvs(fpaths)
comments_clean = clean_comments(comments)
conv_to_lowerc       (2400, 5)  0:00:00.001485      
rm_punct             (2400, 5)  0:00:00.014857      
tokenize             (2400, 5)  0:00:00.003988      
rem_short_comments   (1695, 5)  0:00:00.002041      
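
The cleaning helpers logged above (conv_to_lowerc, rm_punct, tokenize, rem_short_comments) are defined elsewhere in the package. As a rough orientation, a minimal sketch of an equivalent pipeline, assuming a 'body' column of raw comment strings and a guessed minimum-length threshold:

import pandas as pd

def clean_comments_sketch(comments: pd.DataFrame, min_tokens: int = 10) -> pd.DataFrame:
    # hypothetical re-implementation mirroring the logged steps:
    # conv_to_lowerc -> rm_punct -> tokenize -> rem_short_comments
    df = comments.copy()
    df['body'] = df['body'].str.lower()                                # conv_to_lowerc
    df['body'] = df['body'].str.replace(r'[^\w\s]', ' ', regex=True)   # rm_punct
    df['body'] = df['body'].str.split()                                # tokenize
    return df[df['body'].str.len() >= min_tokens]                      # rem_short_comments
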
class Corpus:
    """An iterator that yields sentences (lists of str)."""
    def __init__(self, docs):
        self.docs_clean = docs

    def __iter__(self):
        for doc in self.docs_clean:
            yield doc

source

Corpus

 Corpus (docs)

An iterator that yields sentences (lists of str).

corpus = Corpus(comments_clean['body'].to_list())

train model

def train_model(corpus,
                MIN_COUNT=5,
                SIZE=300,
                WORKERS=8,
                WINDOW=5,
                EPOCHS=5
                ):
    model = Word2Vec(
        corpus,
        min_count=MIN_COUNT,
        vector_size=SIZE,
        workers=WORKERS,
        window=WINDOW,
        epochs=EPOCHS
    )
    return model

source

train_model

 train_model (corpus, MIN_COUNT=5, SIZE=300, WORKERS=8, WINDOW=5,
              EPOCHS=5)
model = train_model(corpus)
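
The save step is not shown in this notebook, but load_models below expects .model files in a models directory. A minimal sketch, assuming the ../models_test layout used there and a file name matching the subreddit:

from pathlib import Path

models_dir = Path('../models_test')               # default directory used by load_models()
models_dir.mkdir(parents=True, exist_ok=True)
model.save(str(models_dir / 'conspiracy.model'))  # hypothetical file name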

load models

def load_models(model_names: list, models_dir: str='../models_test') -> dict:
    models = {}
    for name in model_names:
        try:
            models[name] = Word2Vec.load(f'{models_dir}/{name}.model')
        except FileNotFoundError:
            print(f"Model '{name}' not found in '{models_dir}'.")
    return models

source

load_models

 load_models (model_names:list, models_dir:str='../models_test')
models = load_models(['2019', '2020', 'Coronavirus', 'conspiracy'])
models
{'2019': <gensim.models.word2vec.Word2Vec>,
 '2020': <gensim.models.word2vec.Word2Vec>,
 'Coronavirus': <gensim.models.word2vec.Word2Vec>,
 'conspiracy': <gensim.models.word2vec.Word2Vec>}

align models

assert len(models['2019'].wv.key_to_index) != len(models['2020'].wv.key_to_index)
def intersection_align_gensim(m1, m2, words=None):
    """
    Intersect two gensim word2vec models, m1 and m2.
    Only the shared vocabulary between them is kept.
    If 'words' is set (as list or set), then the vocabulary is intersected with this list as well.
    Indices are re-organized from 0..N in order of descending frequency (=sum of counts from both m1 and m2).
    These indices correspond to the new syn0 and syn0norm objects in both gensim models:
        -- so that Row 0 of m1.syn0 will be for the same word as Row 0 of m2.syn0
        -- you can find the index of any word on the .index2word list: model.index2word.index(word) => 2
    The .vocab dictionary is also updated for each model, preserving the count but updating the index.
    """

    # Get the vocab for each model
    vocab_m1 = set(m1.wv.index_to_key)
    vocab_m2 = set(m2.wv.index_to_key)

    # Find the common vocabulary
    common_vocab = vocab_m1 & vocab_m2
    if words: common_vocab &= set(words)

    # If no alignment necessary because vocab is identical...
    if not vocab_m1 - common_vocab and not vocab_m2 - common_vocab:
        return (m1,m2)

    # Otherwise sort by frequency (summed for both)
    common_vocab = list(common_vocab)
    common_vocab.sort(key=lambda w: m1.wv.get_vecattr(w, "count") + m2.wv.get_vecattr(w, "count"), reverse=True)
    # print(len(common_vocab))

    # Then for each model...
    for m in [m1, m2]:
        # Replace old syn0norm array with new one (with common vocab)
        indices = [m.wv.key_to_index[w] for w in common_vocab]
        old_arr = m.wv.vectors
        new_arr = np.array([old_arr[index] for index in indices])
        m.wv.vectors = new_arr

        # Replace old vocab dictionary with new one (with common vocab)
        # and old index2word with new one
        new_key_to_index = {}
        new_index_to_key = []
        for new_index, key in enumerate(common_vocab):
            new_key_to_index[key] = new_index
            new_index_to_key.append(key)
        m.wv.key_to_index = new_key_to_index
        m.wv.index_to_key = new_index_to_key
        
        print(len(m.wv.key_to_index), len(m.wv.vectors))
        
    return (m1,m2)

source

intersection_align_gensim

 intersection_align_gensim (m1, m2, words=None)

Intersect two gensim word2vec models, m1 and m2. Only the shared vocabulary between them is kept. If ‘words’ is set (as list or set), then the vocabulary is intersected with this list as well. Indices are re-organized from 0..N in order of descending frequency (=sum of counts from both m1 and m2). These indices correspond to the new syn0 and syn0norm objects in both gensim models: – so that Row 0 of m1.syn0 will be for the same word as Row 0 of m2.syn0 – you can find the index of any word on the .index2word list: model.index2word.index(word) => 2 The .vocab dictionary is also updated for each model, preserving the count but updating the index.

def smart_procrustes_align_gensim(base_embed, other_embed, words=None):
    """
    Original script: https://gist.github.com/quadrismegistus/09a93e219a6ffc4f216fb85235535faf
    Procrustes align two gensim word2vec models (to allow for comparison between same word across models).
    Code ported from HistWords <https://github.com/williamleif/histwords> by William Hamilton <wleif@stanford.edu>.
        
    First, intersect the vocabularies (see `intersection_align_gensim` documentation).
    Then do the alignment on the other_embed model.
    Replace the other_embed model's syn0 and syn0norm numpy matrices with the aligned version.
    Return other_embed.
    If `words` is set, intersect the two models' vocabulary with the vocabulary in words (see `intersection_align_gensim` documentation).
    """

    # make sure vocabulary and indices are aligned
    in_base_embed, in_other_embed = intersection_align_gensim(base_embed, other_embed, words=words)

    # get the (normalized) embedding matrices
    base_vecs = in_base_embed.wv.get_normed_vectors()
    other_vecs = in_other_embed.wv.get_normed_vectors()

    # just a matrix dot product with numpy
    m = other_vecs.T.dot(base_vecs) 
    # SVD method from numpy
    u, _, v = np.linalg.svd(m)
    # another matrix operation
    ortho = u.dot(v) 
    # Replace original array with modified one, i.e. multiplying the embedding matrix by "ortho"
    other_embed.wv.vectors = (other_embed.wv.vectors).dot(ortho)    
    
    return other_embed

source

smart_procrustes_align_gensim

 smart_procrustes_align_gensim (base_embed, other_embed, words=None)

Original script: https://gist.github.com/quadrismegistus/09a93e219a6ffc4f216fb85235535faf Procrustes align two gensim word2vec models (to allow for comparison between same word across models). Code ported from HistWords https://github.com/williamleif/histwords by William Hamilton <wleif@stanford.edu>.

First, intersect the vocabularies (see intersection_align_gensim documentation). Then do the alignment on the other_embed model. Replace the other_embed model’s syn0 and syn0norm numpy matrices with the aligned version. Return other_embed. If words is set, intersect the two models’ vocabulary with the vocabulary in words (see intersection_align_gensim documentation).

smart_procrustes_align_gensim(models['2019'], models['2020'])
1710 1710
1710 1710
<gensim.models.word2vec.Word2Vec>
assert len(models['2019'].wv.key_to_index) == len(models['2020'].wv.vectors)
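
Since the Procrustes step rotates models['2020'] into the coordinate system of models['2019'] with an orthogonal matrix, within-model geometry is preserved while the two models now index the shared vocabulary identically. A small sanity check (not part of the original notebook):

from scipy import spatial

w = 'people'  # any word from the shared vocabulary

# after alignment the same row index refers to the same word in both models
assert models['2019'].wv.key_to_index[w] == models['2020'].wv.key_to_index[w]

# and cosine distances across models are now meaningful
print(spatial.distance.cosine(models['2019'].wv[w], models['2020'].wv[w]))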

measure distances between types

def measure_distances(model_1, model_2):
    distances = pd.DataFrame(
        columns=('lex', 'dist_sem', "freq_1", "freq_2"),
        data=(
            #[w, spatial.distance.euclidean(model_1.wv[w], model_2.wv[w]),
            #[w, np.sum(model_1.wv[w] * model_2.wv[w]) / (np.linalg.norm(model_1.wv[w]) * np.linalg.norm(model_2.wv[w])),
            [w, spatial.distance.cosine(model_1.wv[w], model_2.wv[w]),
             model_1.wv.get_vecattr(w, "count"),
             model_2.wv.get_vecattr(w, "count")
             ] for w in model_1.wv.index_to_key
        )
    )
    return distances

source

measure_distances

 measure_distances (model_1, model_2)
distances = measure_distances(models['2019'], models['2020'])

distances\
    .sort_values('dist_sem', ascending=False)
lex dist_sem freq_1 freq_2
130 bot 0.179056 106 123
147 action 0.172512 98 112
1348 forget 0.171441 8 9
62 any 0.155413 243 272
94 am 0.153691 158 174
... ... ... ... ...
364 posts 0.000489 32 37
300 big 0.000456 41 47
227 life 0.000440 58 67
242 once 0.000392 55 62
265 another 0.000371 48 56

1710 rows × 4 columns


source

get_change_candidates

 get_change_candidates (k:int, distances:pandas.core.frame.DataFrame,
                        freq_min:int=100, propNouns:bool=True)
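
The body of get_change_candidates is not shown in this section. Judging from its signature and the distances frame above, it presumably returns the k lexemes with the largest semantic distance among sufficiently frequent words; a rough sketch under those assumptions (the propNouns behaviour is guessed):

import pandas as pd

def get_change_candidates_sketch(k: int, distances: pd.DataFrame,
                                 freq_min: int = 100, propNouns: bool = True) -> pd.DataFrame:
    # keep lexemes that are reasonably frequent in both corpora
    candidates = distances.query('freq_1 > @freq_min and freq_2 > @freq_min')
    if not propNouns:
        # guessed behaviour: drop capitalised forms; largely moot here,
        # since the corpus is lower-cased during cleaning
        candidates = candidates[~candidates['lex'].str.istitle()]
    # rank by semantic distance between the aligned models
    return candidates.nlargest(k, 'dist_sem')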

get nearest semantic neighbours

def get_nearest_neighbours_models(lex: str, freq_min: int, model_1, model_2, topn: int=100_000, k: int=10):
    nbs = []
    for count, model in enumerate([model_1, model_2]):
        for nb, sim in model.wv.most_similar(lex, topn=topn):
            if model.wv.get_vecattr(nb, 'count') > freq_min:
                d = {}
                d['Model'] = count + 1
                d['Word'] = nb
                d['SemDist'] = round(1 - sim, 2)
                d['Freq'] = model.wv.get_vecattr(nb, "count")
                d['vec'] = model.wv.get_vector(lex)  # vector of the query word lex (same for every neighbour row)
                nbs.append(d)
    nbs_df = pd.DataFrame(nbs)
    nbs_df = nbs_df\
        .query('Freq > @freq_min')\
        .groupby('Model', group_keys=False)\
        .apply(lambda group: group.nsmallest(k, 'SemDist'))
    nbs_model_1 = nbs_df.query('Model == 1')
    nbs_model_2 = nbs_df.query('Model == 2')
    return nbs_model_1, nbs_model_2

source

get_nearest_neighbours_models

 get_nearest_neighbours_models (lex:str, freq_min:int, model_1, model_2,
                                topn:int=100000, k:int=10)
nbs_1, nbs_2 = get_nearest_neighbours_models('good', 5, models['2019'], models['2020'])
/var/folders/gp/dw55jb3d3gl6jn22rscvxjm40000gn/T/ipykernel_48107/1648447668.py:18: DeprecationWarning: DataFrameGroupBy.apply operated on the grouping columns. This behavior is deprecated, and in a future version of pandas the grouping columns will be excluded from the operation. Either pass `include_groups=False` to exclude the groupings or explicitly select the grouping columns after groupby to silence this warning.
  .apply(lambda group: group.nsmallest(k, 'SemDist'))
nbs_1
Model Word SemDist Freq vec
0 1 her 0.0 166 [-0.00916857, 0.2121922, -0.025869183, 0.07142...
1 1 me 0.0 332 [-0.00916857, 0.2121922, -0.025869183, 0.07142...
2 1 much 0.0 150 [-0.00916857, 0.2121922, -0.025869183, 0.07142...
3 1 as 0.0 487 [-0.00916857, 0.2121922, -0.025869183, 0.07142...
4 1 people 0.0 293 [-0.00916857, 0.2121922, -0.025869183, 0.07142...
5 1 one 0.0 280 [-0.00916857, 0.2121922, -0.025869183, 0.07142...
6 1 had 0.0 178 [-0.00916857, 0.2121922, -0.025869183, 0.07142...
7 1 now 0.0 150 [-0.00916857, 0.2121922, -0.025869183, 0.07142...
8 1 all 0.0 374 [-0.00916857, 0.2121922, -0.025869183, 0.07142...
9 1 him 0.0 146 [-0.00916857, 0.2121922, -0.025869183, 0.07142...
nbs_2
Model Word SemDist Freq vec
1709 2 well 0.0 144 [-0.0038482083, 0.21741627, -0.03928126, 0.050...
1710 2 also 0.0 184 [-0.0038482083, 0.21741627, -0.03928126, 0.050...
1711 2 me 0.0 371 [-0.0038482083, 0.21741627, -0.03928126, 0.050...
1712 2 ve 0.0 178 [-0.0038482083, 0.21741627, -0.03928126, 0.050...
1713 2 means 0.0 33 [-0.0038482083, 0.21741627, -0.03928126, 0.050...
1714 2 much 0.0 167 [-0.0038482083, 0.21741627, -0.03928126, 0.050...
1715 2 yeah 0.0 81 [-0.0038482083, 0.21741627, -0.03928126, 0.050...
1716 2 idea 0.0 41 [-0.0038482083, 0.21741627, -0.03928126, 0.050...
1717 2 sure 0.0 114 [-0.0038482083, 0.21741627, -0.03928126, 0.050...
1718 2 when 0.0 296 [-0.0038482083, 0.21741627, -0.03928126, 0.050...

semantic axes

models = load_models(['Coronavirus', 'conspiracy'])
def get_pole_avg(model_name: str, model: Word2Vec, lex, k=10):
    words = []
    vecs = []
    vecs.append(model.wv[lex])
    df = (pd.read_csv(f"../pole-words/{model_name}_{lex}.csv")
        .query('Include != "f"')
        .nlargest(k, 'SemSim')
    )
    pole_words = df['Word'].tolist()
    for word in pole_words:
        if word in model.wv:
            vecs.append(model.wv[word])
            words.append(word)
    pole_avg = np.mean(vecs, axis=0)
    return pole_avg

source

get_pole_avg

 get_pole_avg (model_name:str, model:gensim.models.word2vec.Word2Vec, lex,
               k=10)
get_pole_avg('Coronavirus', models['Coronavirus'], 'good').shape
(300,)
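
get_pole_avg assumes one CSV per model/pole word under ../pole-words/, with at least Word, SemSim and Include columns (rows marked Include == "f" are dropped, then the k entries with the highest SemSim are averaged together with the pole word's own vector). A quick check of that assumed layout:

import pandas as pd

poles = pd.read_csv('../pole-words/Coronavirus_good.csv')
assert {'Word', 'SemSim', 'Include'} <= set(poles.columns)
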
def make_sem_axis_avg(model_name: str, model: Word2Vec, pole_word_1: str, pole_word_2: str, k=10):
    pole_1_avg = get_pole_avg(model_name, model, pole_word_1, k)
    pole_2_avg = get_pole_avg(model_name, model, pole_word_2, k)
    sem_axis = pole_1_avg - pole_2_avg
    return sem_axis

source

make_sem_axis_avg

 make_sem_axis_avg (model_name:str, model:gensim.models.word2vec.Word2Vec,
                    pole_word_1:str, pole_word_2:str, k=10)
make_sem_axis_avg('Coronavirus', models['Coronavirus'], 'good', 'bad').shape
(300,)
def get_axis_sim(lex: str, pole_word_1: str, pole_word_2: str, model_name, model, k=10):
    sem_axis = make_sem_axis_avg(model_name, model, pole_word_1, pole_word_2, k)
    lex_vec = model.wv.get_vector(lex)
    sim_cos = 1 - spatial.distance.cosine(lex_vec, sem_axis)
    return sim_cos

source

get_axis_sim

 get_axis_sim (lex:str, pole_word_1:str, pole_word_2:str, model_name,
               model, k=10)
get_axis_sim('vaccines', 'good', 'bad', 'Coronavirus', models['Coronavirus'], k=10)
0.9172929190814479
def get_axis_sims(lexs: list, models: dict, pole_words: list, k=10):
    sims = []
    for lex in lexs:
        for name, model in models.items():
            sim = {}
            sim['model'] = name
            sim['lex'] = lex
            sim['sim'] = get_axis_sim(lex, pole_words[0], pole_words[1], name, model, k)
            sims.append(sim)
    sims_df = pd.DataFrame(sims)
    return sims_df

source

get_axis_sims

 get_axis_sims (lexs:list, models:dict, pole_words:list, k=10)
proj_sims = get_axis_sims(['vaccines', 'vaccine'], models, ['good', 'bad'])
proj_sims
model lex sim
0 Coronavirus vaccines 0.917293
1 conspiracy vaccines 0.990792
2 Coronavirus vaccine 0.933222
3 conspiracy vaccine 0.977039

source

aggregate_proj_sims

 aggregate_proj_sims (proj_sims)
proj_sims = aggregate_proj_sims(proj_sims)
proj_sims
model lex Coronavirus conspiracy SimDiff
0 vaccine 0.933222 0.977039 -0.043817
1 vaccines 0.917293 0.990792 -0.073499
proj_sims_melted = proj_sims.melt(id_vars=['lex', 'SimDiff'], var_name='model', value_name='SemSim')
proj_sims_melted
lex SimDiff model SemSim
0 vaccine -0.043817 Coronavirus 0.933222
1 vaccines -0.073499 Coronavirus 0.917293
2 vaccine -0.043817 conspiracy 0.977039
3 vaccines -0.073499 conspiracy 0.990792

source

plot_sem_axis

 plot_sem_axis (proj_sims_melted:pandas.core.frame.DataFrame,
                pole_words:list)
plot_sem_axis(proj_sims_melted, ['good', 'bad'])

maps of socio-semantic variation

models = load_models(['Coronavirus', 'conspiracy'])
smart_procrustes_align_gensim(models['Coronavirus'], models['conspiracy'])
1173 1173
1173 1173
<gensim.models.word2vec.Word2Vec>
def get_nbs_vecs(lex: str, model_name: str, model: Word2Vec, k=50):
    lex_vecs = []
    lex_d = {}
    lex_d['lex'] = lex
    lex_d['type'] = 'center'
    lex_d['subreddit'] = model_name
    lex_d['vec'] = model.wv.get_vector(lex)
    lex_vecs.append(lex_d)
    for nb, sim in model.wv.most_similar(lex, topn=k):
        lex_d = {}
        lex_d['lex'] = nb
        lex_d['type'] = 'nb'
        lex_d['sim'] = sim
        lex_d['subreddit'] = model_name
        lex_d['vec'] =  model.wv.get_vector(nb)
        lex_d['freq'] = model.wv.get_vecattr(nb, "count")
        lex_vecs.append(lex_d)
    lex_vecs_df = pd.DataFrame(lex_vecs)
    return lex_vecs_df

source

get_nbs_vecs

 get_nbs_vecs (lex:str, model_name:str,
               model:gensim.models.word2vec.Word2Vec, k=50)
get_nbs_vecs('vaccines', 'Coronavirus', models['Coronavirus'])
lex type subreddit vec sim freq
0 vaccines center Coronavirus [0.057227388, 0.10430427, 0.032422945, 0.06407... NaN NaN
1 health nb Coronavirus [0.114723645, 0.21284297, 0.064497545, 0.12715... 0.999692 46.0
2 per nb Coronavirus [0.090133496, 0.16385178, 0.045519162, 0.09794... 0.999669 29.0
3 deaths nb Coronavirus [0.11426731, 0.20661092, 0.06342163, 0.1284676... 0.999660 49.0
4 1 nb Coronavirus [0.11298724, 0.20978343, 0.061724722, 0.128311... 0.999638 54.0
5 us nb Coronavirus [0.11077853, 0.21411182, 0.06554103, 0.1271514... 0.999608 123.0
6 new nb Coronavirus [0.1211073, 0.22653188, 0.069263056, 0.1391144... 0.999603 76.0
7 years nb Coronavirus [0.09558547, 0.18013933, 0.05468028, 0.1110726... 0.999591 52.0
8 spread nb Coronavirus [0.07730215, 0.14945872, 0.04962611, 0.0888780... 0.999588 35.0
9 0 nb Coronavirus [0.0645965, 0.1220492, 0.035898782, 0.07554879... 0.999573 16.0
10 000 nb Coronavirus [0.075567916, 0.1435639, 0.040953316, 0.084523... 0.999569 29.0
11 7 nb Coronavirus [0.06406919, 0.12866612, 0.03618963, 0.0769648... 0.999565 17.0
12 world nb Coronavirus [0.08870022, 0.159735, 0.04859646, 0.09668462,... 0.999564 74.0
13 another nb Coronavirus [0.06684075, 0.12619431, 0.037108395, 0.077369... 0.999556 42.0
14 public nb Coronavirus [0.08594237, 0.16238208, 0.054486435, 0.097496... 0.999552 39.0
15 hours nb Coronavirus [0.055532847, 0.11181068, 0.032346416, 0.06532... 0.999545 15.0
16 control nb Coronavirus [0.0560417, 0.10506418, 0.030902166, 0.0663534... 0.999544 31.0
17 keep nb Coronavirus [0.08326553, 0.15533412, 0.04887768, 0.0957071... 0.999539 85.0
18 reported nb Coronavirus [0.07306963, 0.1333254, 0.045092944, 0.0796391... 0.999537 19.0
19 average nb Coronavirus [0.058386106, 0.10160504, 0.028476257, 0.06527... 0.999528 17.0
20 sars nb Coronavirus [0.054962628, 0.10801497, 0.030391451, 0.06244... 0.999521 14.0
21 ago nb Coronavirus [0.05723345, 0.1124895, 0.03595097, 0.06608445... 0.999521 29.0
22 science nb Coronavirus [0.06259699, 0.11655103, 0.03798895, 0.0715784... 0.999516 24.0
23 under nb Coronavirus [0.08245373, 0.15575409, 0.054234445, 0.097904... 0.999508 40.0
24 early nb Coronavirus [0.07038852, 0.12683387, 0.044630148, 0.081177... 0.999506 19.0
25 likely nb Coronavirus [0.06962872, 0.13197441, 0.043262962, 0.075477... 0.999506 36.0
26 due nb Coronavirus [0.052545387, 0.09022615, 0.028288025, 0.05750... 0.999492 25.0
27 schools nb Coronavirus [0.07024904, 0.13720992, 0.042491872, 0.087135... 0.999490 22.0
28 state nb Coronavirus [0.08707496, 0.16359739, 0.054630436, 0.101618... 0.999488 47.0
29 quality nb Coronavirus [0.05661944, 0.10371776, 0.03314033, 0.0622458... 0.999486 14.0
30 4 nb Coronavirus [0.057619605, 0.117240496, 0.036980513, 0.0707... 0.999483 26.0
31 help nb Coronavirus [0.074597746, 0.13663343, 0.040477894, 0.08401... 0.999481 34.0
32 available nb Coronavirus [0.050409965, 0.097474575, 0.028630486, 0.0579... 0.999480 13.0
33 america nb Coronavirus [0.055805784, 0.09872247, 0.029724134, 0.06394... 0.999477 30.0
34 response nb Coronavirus [0.04918535, 0.08948149, 0.022905065, 0.054223... 0.999472 12.0
35 high nb Coronavirus [0.10157195, 0.19419962, 0.06280887, 0.1199198... 0.999469 39.0
36 10 nb Coronavirus [0.08214546, 0.14870152, 0.04579945, 0.0888806... 0.999461 32.0
37 am nb Coronavirus [0.096367285, 0.18651997, 0.044814415, 0.11034... 0.999456 102.0
38 city nb Coronavirus [0.06082973, 0.11749765, 0.034887813, 0.073830... 0.999453 19.0
39 months nb Coronavirus [0.09699332, 0.17284966, 0.057289552, 0.105755... 0.999451 44.0
40 currently nb Coronavirus [0.062632866, 0.12303838, 0.03613432, 0.074787... 0.999444 19.0
41 already nb Coronavirus [0.09850482, 0.1830824, 0.059671603, 0.1124184... 0.999443 49.0
42 fact nb Coronavirus [0.063221715, 0.123731576, 0.04299859, 0.07519... 0.999428 44.0
43 places nb Coronavirus [0.058041047, 0.11331363, 0.038818423, 0.06794... 0.999427 18.0
44 provide nb Coronavirus [0.04588933, 0.095273994, 0.023019709, 0.05869... 0.999418 14.0
45 positive nb Coronavirus [0.067284316, 0.12324461, 0.040740743, 0.07202... 0.999414 19.0
46 link nb Coronavirus [0.04832779, 0.08898213, 0.022611374, 0.052865... 0.999399 27.0
47 based nb Coronavirus [0.073736355, 0.14148337, 0.03525661, 0.082236... 0.999397 24.0
48 20 nb Coronavirus [0.09155286, 0.18733306, 0.051601738, 0.109539... 0.999391 28.0
49 than nb Coronavirus [0.14099781, 0.26505557, 0.09407188, 0.1668624... 0.999377 124.0
50 open nb Coronavirus [0.08959826, 0.16488104, 0.051867336, 0.098772... 0.999374 36.0
nbs_vecs = pd.concat([get_nbs_vecs('vaccines', model_name, model, k=750) for model_name, model in models.items()])
nbs_vecs['vec'].iloc[0].shape
(300,)
def dim_red_nbs_vecs(nbs_vecs: pd.DataFrame, perplexity=50, n_iter=1000):
    Y_tsne = TSNE(
            perplexity=perplexity,
            method='exact',
            init='pca',
            verbose=False,
            learning_rate='auto',
            n_iter=n_iter
        ).fit_transform(np.array(list(nbs_vecs['vec'])))
    nbs_vecs['x_tsne'] = Y_tsne[:, 0]
    nbs_vecs['y_tsne'] = Y_tsne[:, 1]
    return nbs_vecs

source

dim_red_nbs_vecs

 dim_red_nbs_vecs (nbs_vecs:pandas.core.frame.DataFrame, perplexity=50,
                   n_iter=1000)
nbs_vecs_dimred = dim_red_nbs_vecs(nbs_vecs, n_iter=250)
print(nbs_vecs_dimred['x_tsne'].iloc[0].shape, nbs_vecs_dimred['y_tsne'].iloc[0].shape)
() ()

common neighbours

nbs_vecs = dim_red_nbs_vecs(nbs_vecs, perplexity=0.1, n_iter=250)
nbs_sim = (nbs_vecs
    .groupby('subreddit')
    .apply(lambda df: df.nlargest(10, 'sim'))
    .reset_index(drop=True)
)
/var/folders/gp/dw55jb3d3gl6jn22rscvxjm40000gn/T/ipykernel_48107/1291324649.py:3: DeprecationWarning: DataFrameGroupBy.apply operated on the grouping columns. This behavior is deprecated, and in a future version of pandas the grouping columns will be excluded from the operation. Either pass `include_groups=False` to exclude the groupings or explicitly select the grouping columns after groupby to silence this warning.
  .apply(lambda df: df.nlargest(10, 'sim'))
chart_sims = (alt.Chart(nbs_sim).mark_text().encode(
        x='x_tsne:Q',
        y='y_tsne:Q',
        text='lex',
        color='subreddit:N'
    ))

chart_sims

differences in neighbours

nbs_vecs = dim_red_nbs_vecs(nbs_vecs, perplexity=70, n_iter=250)
nbs_diff = nbs_vecs.drop_duplicates(subset='lex', keep=False)
nbs_diff = (nbs_diff
    .groupby('subreddit')
    .apply(lambda df: df.nlargest(20, 'sim'))
    .reset_index(drop=True)
)
/var/folders/gp/dw55jb3d3gl6jn22rscvxjm40000gn/T/ipykernel_48107/628690727.py:4: DeprecationWarning: DataFrameGroupBy.apply operated on the grouping columns. This behavior is deprecated, and in a future version of pandas the grouping columns will be excluded from the operation. Either pass `include_groups=False` to exclude the groupings or explicitly select the grouping columns after groupby to silence this warning.
  .apply(lambda df: df.nlargest(20, 'sim'))
chart_diffs = (alt.Chart(nbs_diff).mark_text().encode(
        x='x_tsne:Q',
        y='y_tsne:Q',
        text='lex:N',
        color='subreddit:N'
    ))


chart_diffs