RNA/ADT interpretability tutorial

This tutorial shows how to run Champollion on a processed RNA/ADT dataset and then inspect the learned interaction matrix A. The goal is to illustrate the main steps of the workflow and the main result objects, not to provide a full analysis recipe.

The dataset used here is already preprocessed. For guidance on preparing a new dataset, see the Data Inputs section of the documentation.

[1]:
import warnings
warnings.filterwarnings("ignore")
[2]:
from pathlib import Path

import muon as mu
import numpy as np
import pandas as pd
from gprofiler import GProfiler
from IPython.display import display

from champollion import Champollion
from champollion.plot import plot_aggregated_transport_plan, top_interactions_bar
[KeOps] Warning : CUDA libraries not found or could not be loaded; Switching to CPU only.

Configuration

[8]:
DATA_PATH = Path(
    "../semi-paired/data/cite_opp_tuto_processed.h5mu"
)

RNA_MOD = "rna"
ADT_MOD = "adt"
CELLTYPE_KEY = "celltype"

Prepare the tutorial inputs

This dataset is fully paired, but in this tutorial we split it to mimic the semi-paired setting Champollion is designed for. This setup detail is made optional as it is not central to the point of this tutorial. For your own data and preprocessing choices, see the Data Inputs section of the documentation.

After this preparation step, we will have three objects: one bridge MuData object used for fitting, plus one pseudo-unpaired RNA object and one pseudo-unpaired ADT object used for transport and interpretation.

[9]:
if not DATA_PATH.exists():
    raise FileNotFoundError(
        f"Could not find {DATA_PATH}. Update DATA_PATH for your local data location."
    )

mdata = mu.read_h5mu(DATA_PATH)
print(mdata)
print("modalities:", list(mdata.mod.keys()))
MuData object with n_obs × n_vars = 10966 × 14087
  obs:  'celltype', 'fine_celltype', 'paired_mask'
  uns:  'dataset_name'
  2 modalities
    adt:        10966 × 134
      obs:      'batch', 'Site', 'DonorNumber', 'celltype', 'fine_celltype'
      var:      'feature_types', 'gene_id', 'variability_score', 'mean', 'std'
      obsm:     'X_pca', 'prior_data'
      layers:   'counts', 'scaled'
    rna:        10966 × 13953
      obs:      'batch', 'Site', 'DonorNumber', 'celltype', 'fine_celltype', 'n_genes'
      var:      'feature_types', 'gene_id', 'n_cells', 'highly_variable', 'means', 'dispersions', 'dispersions_norm', 'variability_score', 'mean', 'std'
      uns:      'hvg', 'log1p'
      obsm:     'X_pca', 'prior_data'
      layers:   'counts', 'scaled'
modalities: ['adt', 'rna']
[11]:
PAIRED_MASK_KEY = "paired_mask"

DEBUG = False
N_HVG = 250 if DEBUG else 4000
DEBUG_TRAIN_CELLS = 60
DEBUG_TEST_CELLS = 80

coding_genes = ["TNFRSF4", "TNFRSF14", "TNFRSF9", "CD52", "CD58", "CD2", "CD101", "FCGR1A", "CD1D", "CD1C", "FCER1A", "SLAMF6", "CD48", "SLAMF7", "CD244", "FCGR2A", "FCGR3A", "SELL", "PTPRC", "CR2", "CR1", "CD8A", "DPP4", "ITGA6", "ITGA4", "CD28", "CTLA4", "ICOS", "CCR4", "CX3CR1", "CCR2", "CCR5", "CD47", "BTLA", "TIGIT", "CD86", "TFRC", "CD38", "IL7R", "ITGA1", "CD14", "CSF1R", "CD83", "HLA-A", "HLA-E", "HLA-C", "HLA-B", "HLA-DRA", "NT5E", "CD24", "IFNGR1", "CCR6", "CD36", "CD72", "ENG", "IL2RA", "ITGB1", "NRP1", "FAS", "ENTPD1", "CD81", "CD44", "CD82", "MS4A1", "CD5", "NCAM1", "CD3G", "CXCR5", "CD9", "CD27", "LAG3", "CD4", "CD163", "CLEC4C", "KLRG1", "KLRB1", "CD69", "KLRD1", "KLRK1", "ITGB7", "CD63", "SELPLG", "LAMP1", "TRAC", "IGHD", "IGHM", "ANPEP", "IL4R", "CD19", "ITGAL", "ITGAM", "ITGAX", "ITGAE", "GP1BA", "CD79B", "PECAM1", "CD7", "CD226", "FCER2", "ICAM1", "CD22", "PVR", "NECTIN2", "C5AR1", "SIGLEC7", "CD33", "LILRB1", "NCR1", "SIRPA", "THBD", "CD93", "CD40", "ITGB2", "IL2RB", "TNFRSF13C", "IL3RA", "CD40LG"]

CELLTYPE_ORDER = ["HSC", "Lymph prog", "G/M prog", "MK/E prog", "Proerythroblast", "Erythroblast", "Normoblast", "Reticulocyte", "CD14+ Mono", "CD16+ Mono", "cDC2", "pDC", "Naive CD20+ B", "Transitional B", "B1 B", "Plasma cell", "Plasmablast", "CD4+ T naive", "CD4+ T", "CD4+ T activated", "CD8+ T naive", "CD8+ T", "ILC1", "NK", "MAIT"]

def get_hvg_names(adata, n_features=4000, keep_features=None):
    """Return HVGs, optionally forcing a set of biologically linked genes to stay."""
    if "variability_score" not in adata.var:
        raise KeyError(
            "Expected adata.var['variability_score'] in the processed RNA object."
        )

    var_annot = adata.var.sort_values(ascending=False, by="variability_score")
    n_features = min(n_features, len(var_annot))
    min_score = var_annot["variability_score"].iloc[n_features - 1]
    mask = adata.var["variability_score"] >= min_score

    if keep_features is not None:
        keep_features = [gene for gene in keep_features if gene in mask.index]
        mask.loc[keep_features] = True

    return list(adata.var_names[mask])


gene_names = get_hvg_names(
    mdata.mod[RNA_MOD], n_features=N_HVG, keep_features=coding_genes
)

paired_mask = np.asarray(mdata.obs[PAIRED_MASK_KEY].to_numpy(), dtype=bool)
celltypes = mdata.obs[CELLTYPE_KEY].astype(str).to_numpy()

def stratified_indices(mask, labels, n=None, random_state=0):
    """Return indices within mask, optionally stratified down to n cells."""
    idx = np.flatnonzero(mask)
    if n is None or n >= len(idx):
        return idx

    rng = np.random.default_rng(random_state)
    selected = []
    present_labels = [
        label for label in pd.unique(labels[idx]) if np.any(labels[idx] == label)
    ]
    per_group = max(1, int(np.ceil(n / len(present_labels))))
    for label in present_labels:
        label_idx = idx[labels[idx] == label]
        take = min(per_group, len(label_idx))
        selected.extend(rng.choice(label_idx, size=take, replace=False))
    selected = np.asarray(selected, dtype=int)
    if len(selected) > n:
        selected = rng.choice(selected, size=n, replace=False)
    return np.sort(selected)


train_idx = stratified_indices(
    paired_mask,
    labels=celltypes,
    n=DEBUG_TRAIN_CELLS if DEBUG else None,
    random_state=0,
)
test_idx = stratified_indices(
    ~paired_mask,
    labels=celltypes,
    n=DEBUG_TEST_CELLS if DEBUG else None,
    random_state=1,
)

train_mdata = mu.MuData(
    {
        RNA_MOD: mdata.mod[RNA_MOD][train_idx, gene_names].copy(),
        ADT_MOD: mdata.mod[ADT_MOD][train_idx].copy(),
    }
)

adata_rna = mdata.mod[RNA_MOD][test_idx, gene_names].copy()
adata_adt = mdata.mod[ADT_MOD][test_idx].copy()
[12]:
print("training bridge data:")
print(train_mdata)
print("pseudo-unpaired RNA:", adata_rna.shape)
print("pseudo-unpaired ADT:", adata_adt.shape)
print(f"RNA features kept for A: {len(gene_names)}")
print(f"ADT proteins kept for A: {adata_adt.n_vars}")
training bridge data:
MuData object with n_obs × n_vars = 8783 × 4165
  2 modalities
    rna:        8783 × 4031
      obs:      'batch', 'Site', 'DonorNumber', 'celltype', 'fine_celltype', 'n_genes'
      var:      'feature_types', 'gene_id', 'n_cells', 'highly_variable', 'means', 'dispersions', 'dispersions_norm', 'variability_score', 'mean', 'std'
      uns:      'hvg', 'log1p'
      obsm:     'X_pca', 'prior_data'
      layers:   'counts', 'scaled'
    adt:        8783 × 134
      obs:      'batch', 'Site', 'DonorNumber', 'celltype', 'fine_celltype'
      var:      'feature_types', 'gene_id', 'variability_score', 'mean', 'std'
      obsm:     'X_pca', 'prior_data'
      layers:   'counts', 'scaled'
pseudo-unpaired RNA: (2183, 4031)
pseudo-unpaired ADT: (2183, 134)
RNA features kept for A: 4031
ADT proteins kept for A: 134

Fit Champollion

Fitting learns the matrix A from the bridge cells. Because this tutorial uses direct scaled RNA and ADT features, rows and columns of A can be read as gene-protein association scores.

[13]:
model = Champollion(
    epsilon=1.0,
    gamma=0.01,
    lambda_prior=20.0,
    use_keops=False,
    device="auto",
    random_state=0,
    max_iter=2 if DEBUG else 300_000,
    learning_rate=1e-3,
    sinkhorn_tol=1e-3,
    log_every=1 if DEBUG else 40,
    verbose=True,
)
model.fit(
    train_mdata,
    modality_1=RNA_MOD,
    modality_2=ADT_MOD,
    x_1_rep="layers/scaled",
    x_2_rep="layers/scaled",
)

A_df = model.A_dataframe()
print("A shape:", A_df.shape)
print("A device:", model.A_.device)
pd.DataFrame(
    {key: pd.Series(value) for key, value in model.training_history_.items()}
).tail()
  0%|          | 5/300000 [00:00<12:34:06,  6.63it/s]
train_total_loss: 83.64341735839844
train_obj_loss: 29.6660099029541
train_reg_loss: 53.97740936279297
train_trace_term: 0.28358766436576843
train_err_1: 1.7106705904006958
train_plan_mass: 1.0000009536743164
  0%|          | 44/300000 [00:02<2:43:33, 30.56it/s]
train_total_loss: 36.16069793701172
train_obj_loss: -0.7098903656005859
train_reg_loss: 36.87059020996094
train_trace_term: -17.491703033447266
train_err_1: 0.6642380356788635
train_plan_mass: 1.0000001192092896
  0%|          | 85/300000 [00:03<2:27:05, 33.98it/s]
train_total_loss: 18.979022979736328
train_obj_loss: -5.561318397521973
train_reg_loss: 24.540340423583984
train_trace_term: -17.214675903320312
train_err_1: 0.3991648554801941
train_plan_mass: 1.0000001192092896
  0%|          | 125/300000 [00:04<2:37:45, 31.68it/s]
train_total_loss: 9.871883392333984
train_obj_loss: -6.612461090087891
train_reg_loss: 16.484344482421875
train_trace_term: -15.324743270874023
train_err_1: 0.30400556325912476
train_plan_mass: 1.0000001192092896
  0%|          | 165/300000 [00:05<2:28:13, 33.72it/s]
train_total_loss: 4.574556827545166
train_obj_loss: -6.823263645172119
train_reg_loss: 11.397820472717285
train_trace_term: -13.797103881835938
train_err_1: 0.2642056345939636
train_plan_mass: 1.0000001192092896
  0%|          | 206/300000 [00:06<2:37:26, 31.74it/s]
train_total_loss: 1.4049677848815918
train_obj_loss: -6.842285633087158
train_reg_loss: 8.24725341796875
train_trace_term: -12.920703887939453
train_err_1: 0.2460447996854782
train_plan_mass: 1.0
  0%|          | 247/300000 [00:08<2:28:23, 33.67it/s]
train_total_loss: -0.5245165824890137
train_obj_loss: -6.82151985168457
train_reg_loss: 6.297003269195557
train_trace_term: -12.495957374572754
train_err_1: 0.23713943362236023
train_plan_mass: 1.0000001192092896
  0%|          | 287/300000 [00:09<2:36:44, 31.87it/s]
train_total_loss: -1.7224383354187012
train_obj_loss: -6.801453590393066
train_reg_loss: 5.079015254974365
train_trace_term: -12.326879501342773
train_err_1: 0.23193304240703583
train_plan_mass: 1.0
  0%|          | 327/300000 [00:10<2:29:42, 33.36it/s]
train_total_loss: -2.4827170372009277
train_obj_loss: -6.790687084197998
train_reg_loss: 4.30797004699707
train_trace_term: -12.290518760681152
train_err_1: 0.2276957631111145
train_plan_mass: 1.0
  0%|          | 367/300000 [00:11<2:24:43, 34.51it/s]
train_total_loss: -2.976349115371704
train_obj_loss: -6.7884039878845215
train_reg_loss: 3.8120548725128174
train_trace_term: -12.319276809692383
train_err_1: 0.2235061526298523
train_plan_mass: 1.0
  0%|          | 408/300000 [00:12<2:31:25, 32.97it/s]
train_total_loss: -3.30483341217041
train_obj_loss: -6.7929229736328125
train_reg_loss: 3.4880895614624023
train_trace_term: -12.380534172058105
train_err_1: 0.21885468065738678
train_plan_mass: 1.0000001192092896
  0%|          | 446/300000 [00:13<2:06:14, 39.55it/s]
train_total_loss: -3.5289196968078613
train_obj_loss: -6.8017964363098145
train_reg_loss: 3.272876739501953
train_trace_term: -12.458037376403809
train_err_1: 0.21375218033790588
train_plan_mass: 1.0000001192092896
  0%|          | 485/300000 [00:14<2:20:17, 35.58it/s]
train_total_loss: -3.685533046722412
train_obj_loss: -6.81279993057251
train_reg_loss: 3.1272668838500977
train_trace_term: -12.539905548095703
train_err_1: 0.2082129716873169
train_plan_mass: 1.0
  0%|          | 525/300000 [00:16<2:24:07, 34.63it/s]
train_total_loss: -3.7976062297821045
train_obj_loss: -6.824493885040283
train_reg_loss: 3.0268876552581787
train_trace_term: -12.623461723327637
train_err_1: 0.20267464220523834
train_plan_mass: 1.0000001192092896
  0%|          | 567/300000 [00:17<2:25:45, 34.24it/s]
train_total_loss: -3.8797683715820312
train_obj_loss: -6.836313724517822
train_reg_loss: 2.956545352935791
train_trace_term: -12.703912734985352
train_err_1: 0.1970621943473816
train_plan_mass: 1.0000001192092896
  0%|          | 605/300000 [00:18<2:15:58, 36.70it/s]
train_total_loss: -3.941526412963867
train_obj_loss: -6.847626209259033
train_reg_loss: 2.906099796295166
train_trace_term: -12.780219078063965
train_err_1: 0.19141393899917603
train_plan_mass: 1.0000001192092896
  0%|          | 646/300000 [00:19<2:23:44, 34.71it/s]
train_total_loss: -3.9890594482421875
train_obj_loss: -6.858532905578613
train_reg_loss: 2.869473457336426
train_trace_term: -12.85342788696289
train_err_1: 0.18591246008872986
train_plan_mass: 1.0000001192092896
  0%|          | 686/300000 [00:20<2:22:56, 34.90it/s]
train_total_loss: -4.026348114013672
train_obj_loss: -6.868851184844971
train_reg_loss: 2.842503070831299
train_trace_term: -12.922861099243164
train_err_1: 0.1805248111486435
train_plan_mass: 1.0
  0%|          | 726/300000 [00:21<2:26:29, 34.05it/s]
train_total_loss: -4.056063652038574
train_obj_loss: -6.878418922424316
train_reg_loss: 2.822355270385742
train_trace_term: -12.987800598144531
train_err_1: 0.17524738609790802
train_plan_mass: 1.0
  0%|          | 767/300000 [00:22<2:28:45, 33.52it/s]
train_total_loss: -4.080065727233887
train_obj_loss: -6.8872270584106445
train_reg_loss: 2.807161569595337
train_trace_term: -13.04813003540039
train_err_1: 0.1700839400291443
train_plan_mass: 1.0
  0%|          | 807/300000 [00:24<2:26:41, 33.99it/s]
train_total_loss: -4.099761962890625
train_obj_loss: -6.895297527313232
train_reg_loss: 2.7955358028411865
train_trace_term: -13.104562759399414
train_err_1: 0.16499386727809906
train_plan_mass: 1.0000001192092896
  0%|          | 848/300000 [00:25<2:24:46, 34.44it/s]
train_total_loss: -4.116208553314209
train_obj_loss: -6.902707099914551
train_reg_loss: 2.786498546600342
train_trace_term: -13.157866477966309
train_err_1: 0.15997447073459625
train_plan_mass: 1.0
  0%|          | 888/300000 [00:26<2:30:15, 33.18it/s]
train_total_loss: -4.130165100097656
train_obj_loss: -6.9096221923828125
train_reg_loss: 2.7794573307037354
train_trace_term: -13.208658218383789
train_err_1: 0.15507450699806213
train_plan_mass: 1.0000001192092896
  0%|          | 924/300000 [00:27<2:38:22, 31.47it/s]
train_total_loss: -4.142151832580566
train_obj_loss: -6.9160614013671875
train_reg_loss: 2.773909330368042
train_trace_term: -13.257318496704102
train_err_1: 0.15032948553562164
train_plan_mass: 1.0000001192092896
  0%|          | 965/300000 [00:28<2:25:57, 34.14it/s]
train_total_loss: -4.152547836303711
train_obj_loss: -6.92201566696167
train_reg_loss: 2.769468069076538
train_trace_term: -13.304132461547852
train_err_1: 0.14574739336967468
train_plan_mass: 1.0
  0%|          | 1005/300000 [00:29<2:31:00, 33.00it/s]
train_total_loss: -4.161661624908447
train_obj_loss: -6.927513599395752
train_reg_loss: 2.7658519744873047
train_trace_term: -13.349334716796875
train_err_1: 0.14130371809005737
train_plan_mass: 1.0000001192092896
  0%|          | 1045/300000 [00:30<2:17:25, 36.26it/s]
train_total_loss: -4.169730186462402
train_obj_loss: -6.932551860809326
train_reg_loss: 2.762821674346924
train_trace_term: -13.393157005310059
train_err_1: 0.13702324032783508
train_plan_mass: 1.0000001192092896
  0%|          | 1086/300000 [00:32<2:17:13, 36.31it/s]
train_total_loss: -4.1769514083862305
train_obj_loss: -6.937219619750977
train_reg_loss: 2.760267972946167
train_trace_term: -13.435593605041504
train_err_1: 0.13294047117233276
train_plan_mass: 1.0000001192092896
  0%|          | 1126/300000 [00:33<2:31:18, 32.92it/s]
train_total_loss: -4.183474540710449
train_obj_loss: -6.941565990447998
train_reg_loss: 2.758091449737549
train_trace_term: -13.476727485656738
train_err_1: 0.12907841801643372
train_plan_mass: 1.000000238418579
  0%|          | 1166/300000 [00:34<2:25:48, 34.16it/s]
train_total_loss: -4.189408779144287
train_obj_loss: -6.94560432434082
train_reg_loss: 2.756195545196533
train_trace_term: -13.516610145568848
train_err_1: 0.12539422512054443
train_plan_mass: 1.0000001192092896
  0%|          | 1206/300000 [00:35<2:27:22, 33.79it/s]
train_total_loss: -4.194843292236328
train_obj_loss: -6.949387073516846
train_reg_loss: 2.7545440196990967
train_trace_term: -13.555426597595215
train_err_1: 0.12186814099550247
train_plan_mass: 1.0
  0%|          | 1246/300000 [00:36<2:27:00, 33.87it/s]
train_total_loss: -4.1998443603515625
train_obj_loss: -6.952945232391357
train_reg_loss: 2.753101110458374
train_trace_term: -13.593399047851562
train_err_1: 0.11848947405815125
train_plan_mass: 1.0000001192092896
  0%|          | 1286/300000 [00:37<2:26:52, 33.90it/s]
train_total_loss: -4.204472541809082
train_obj_loss: -6.956295967102051
train_reg_loss: 2.751823663711548
train_trace_term: -13.6306791305542
train_err_1: 0.11524634808301926
train_plan_mass: 1.0000001192092896
  0%|          | 1326/300000 [00:39<2:29:51, 33.22it/s]
train_total_loss: -4.208771228790283
train_obj_loss: -6.9594268798828125
train_reg_loss: 2.7506556510925293
train_trace_term: -13.667181968688965
train_err_1: 0.11212511360645294
train_plan_mass: 1.000000238418579
  0%|          | 1366/300000 [00:40<2:36:20, 31.83it/s]
train_total_loss: -4.2127838134765625
train_obj_loss: -6.962375164031982
train_reg_loss: 2.749591588973999
train_trace_term: -13.703152656555176
train_err_1: 0.10912556201219559
train_plan_mass: 1.0
  0%|          | 1406/300000 [00:41<2:30:52, 32.98it/s]
train_total_loss: -4.216546535491943
train_obj_loss: -6.965155601501465
train_reg_loss: 2.7486090660095215
train_trace_term: -13.738567352294922
train_err_1: 0.10624394565820694
train_plan_mass: 1.0000001192092896
  0%|          | 1446/300000 [00:42<2:26:23, 33.99it/s]
train_total_loss: -4.220088005065918
train_obj_loss: -6.967787265777588
train_reg_loss: 2.747699499130249
train_trace_term: -13.77342700958252
train_err_1: 0.10348615795373917
train_plan_mass: 1.0000001192092896
  0%|          | 1486/300000 [00:43<2:31:21, 32.87it/s]
train_total_loss: -4.22342586517334
train_obj_loss: -6.970270156860352
train_reg_loss: 2.746844530105591
train_trace_term: -13.807872772216797
train_err_1: 0.10080917179584503
train_plan_mass: 1.000000238418579
  1%|          | 1526/300000 [00:44<2:28:31, 33.49it/s]
train_total_loss: -4.22658634185791
train_obj_loss: -6.972628593444824
train_reg_loss: 2.746042013168335
train_trace_term: -13.841869354248047
train_err_1: 0.09823441505432129
train_plan_mass: 1.0000001192092896
  1%|          | 1566/300000 [00:46<2:28:27, 33.50it/s]
train_total_loss: -4.229583740234375
train_obj_loss: -6.974879741668701
train_reg_loss: 2.745296001434326
train_trace_term: -13.875395774841309
train_err_1: 0.09575893729925156
train_plan_mass: 1.000000238418579
  1%|          | 1607/300000 [00:47<2:20:04, 35.50it/s]
train_total_loss: -4.232429504394531
train_obj_loss: -6.977034091949463
train_reg_loss: 2.7446043491363525
train_trace_term: -13.908448219299316
train_err_1: 0.09337112307548523
train_plan_mass: 1.000000238418579
  1%|          | 1647/300000 [00:48<2:25:22, 34.20it/s]
train_total_loss: -4.2351393699646
train_obj_loss: -6.979101181030273
train_reg_loss: 2.743961811065674
train_trace_term: -13.941034317016602
train_err_1: 0.0910739153623581
train_plan_mass: 1.0000001192092896
  1%|          | 1684/300000 [00:49<2:30:10, 33.11it/s]
train_total_loss: -4.237720966339111
train_obj_loss: -6.981077194213867
train_reg_loss: 2.743356227874756
train_trace_term: -13.973125457763672
train_err_1: 0.08885658532381058
train_plan_mass: 1.0000001192092896
  1%|          | 1729/300000 [00:50<1:59:04, 41.75it/s]
train_total_loss: -4.24018669128418
train_obj_loss: -6.982977390289307
train_reg_loss: 2.742790460586548
train_trace_term: -14.004743576049805
train_err_1: 0.08672875165939331
train_plan_mass: 1.0000001192092896
  1%|          | 1769/300000 [00:51<1:54:46, 43.31it/s]
train_total_loss: -4.242544174194336
train_obj_loss: -6.98479700088501
train_reg_loss: 2.742252826690674
train_trace_term: -14.035887718200684
train_err_1: 0.08468294888734818
train_plan_mass: 1.000000238418579
  1%|          | 1809/300000 [00:52<1:54:31, 43.40it/s]
train_total_loss: -4.244797706604004
train_obj_loss: -6.986522197723389
train_reg_loss: 2.741724729537964
train_trace_term: -14.066533088684082
train_err_1: 0.08270114660263062
train_plan_mass: 1.000000238418579
  1%|          | 1849/300000 [00:53<1:54:27, 43.41it/s]
train_total_loss: -4.246955394744873
train_obj_loss: -6.988179683685303
train_reg_loss: 2.7412242889404297
train_trace_term: -14.096769332885742
train_err_1: 0.08077839761972427
train_plan_mass: 1.0000001192092896
  1%|          | 1889/300000 [00:54<1:54:24, 43.43it/s]
train_total_loss: -4.249024391174316
train_obj_loss: -6.989774227142334
train_reg_loss: 2.7407500743865967
train_trace_term: -14.126627922058105
train_err_1: 0.07891368865966797
train_plan_mass: 1.000000238418579
  1%|          | 1929/300000 [00:55<1:54:30, 43.38it/s]
train_total_loss: -4.251009941101074
train_obj_loss: -6.991312026977539
train_reg_loss: 2.740302324295044
train_trace_term: -14.156139373779297
train_err_1: 0.07709992676973343
train_plan_mass: 1.0000001192092896
  1%|          | 1969/300000 [00:56<1:54:27, 43.39it/s]
train_total_loss: -4.252920150756836
train_obj_loss: -6.992798328399658
train_reg_loss: 2.739877939224243
train_trace_term: -14.18531608581543
train_err_1: 0.0753369927406311
train_plan_mass: 1.000000238418579
  1%|          | 2009/300000 [00:56<1:54:23, 43.41it/s]
train_total_loss: -4.254758834838867
train_obj_loss: -6.994231700897217
train_reg_loss: 2.7394726276397705
train_trace_term: -14.21415901184082
train_err_1: 0.07362449169158936
train_plan_mass: 1.0000001192092896
  1%|          | 2049/300000 [00:57<1:55:29, 43.00it/s]
train_total_loss: -4.256534576416016
train_obj_loss: -6.995620250701904
train_reg_loss: 2.7390856742858887
train_trace_term: -14.242680549621582
train_err_1: 0.07195869833230972
train_plan_mass: 1.0000001192092896
  1%|          | 2089/300000 [00:58<1:54:29, 43.37it/s]
train_total_loss: -4.258247375488281
train_obj_loss: -6.99696159362793
train_reg_loss: 2.7387139797210693
train_trace_term: -14.270873069763184
train_err_1: 0.07034049928188324
train_plan_mass: 1.0000001192092896
  1%|          | 2129/300000 [00:59<1:54:21, 43.41it/s]
train_total_loss: -4.259904861450195
train_obj_loss: -6.998263835906982
train_reg_loss: 2.738358974456787
train_trace_term: -14.298710823059082
train_err_1: 0.0687701627612114
train_plan_mass: 1.0000001192092896
  1%|          | 2169/300000 [01:00<1:54:20, 43.41it/s]
train_total_loss: -4.261505126953125
train_obj_loss: -6.999521255493164
train_reg_loss: 2.738016366958618
train_trace_term: -14.326194763183594
train_err_1: 0.06724654138088226
train_plan_mass: 1.0000001192092896
  1%|          | 2209/300000 [01:01<1:54:22, 43.39it/s]
train_total_loss: -4.263052940368652
train_obj_loss: -7.000739574432373
train_reg_loss: 2.7376866340637207
train_trace_term: -14.353279113769531
train_err_1: 0.06576843559741974
train_plan_mass: 1.000000238418579
  1%|          | 2249/300000 [01:02<1:54:25, 43.37it/s]
train_total_loss: -4.264549255371094
train_obj_loss: -7.001920223236084
train_reg_loss: 2.7373709678649902
train_trace_term: -14.37997055053711
train_err_1: 0.064328134059906
train_plan_mass: 1.0000001192092896
  1%|          | 2289/300000 [01:03<1:54:20, 43.40it/s]
train_total_loss: -4.26599645614624
train_obj_loss: -7.003065586090088
train_reg_loss: 2.7370691299438477
train_trace_term: -14.406299591064453
train_err_1: 0.0629270076751709
train_plan_mass: 1.0000001192092896
  1%|          | 2329/300000 [01:04<1:54:19, 43.40it/s]
train_total_loss: -4.267401695251465
train_obj_loss: -7.004178047180176
train_reg_loss: 2.73677659034729
train_trace_term: -14.432266235351562
train_err_1: 0.06156628578901291
train_plan_mass: 1.000000238418579
  1%|          | 2369/300000 [01:05<1:54:17, 43.40it/s]
train_total_loss: -4.268760681152344
train_obj_loss: -7.005258560180664
train_reg_loss: 2.736497640609741
train_trace_term: -14.457893371582031
train_err_1: 0.06024269387125969
train_plan_mass: 1.0000001192092896
  1%|          | 2409/300000 [01:06<1:54:17, 43.39it/s]
train_total_loss: -4.270083904266357
train_obj_loss: -7.006318092346191
train_reg_loss: 2.736234188079834
train_trace_term: -14.483238220214844
train_err_1: 0.05895078182220459
train_plan_mass: 1.0000001192092896
  1%|          | 2449/300000 [01:07<1:54:21, 43.37it/s]
train_total_loss: -4.271369934082031
train_obj_loss: -7.007356643676758
train_reg_loss: 2.7359864711761475
train_trace_term: -14.508309364318848
train_err_1: 0.05768932029604912
train_plan_mass: 1.0000001192092896
  1%|          | 2489/300000 [01:07<1:55:29, 42.93it/s]
train_total_loss: -4.27262020111084
train_obj_loss: -7.008367538452148
train_reg_loss: 2.7357470989227295
train_trace_term: -14.533164978027344
train_err_1: 0.056464292109012604
train_plan_mass: 1.0000001192092896
  1%|          | 2529/300000 [01:08<1:54:22, 43.34it/s]
train_total_loss: -4.273836135864258
train_obj_loss: -7.009363174438477
train_reg_loss: 2.7355270385742188
train_trace_term: -14.557787895202637
train_err_1: 0.05527406185865402
train_plan_mass: 1.000000238418579
  1%|          | 2569/300000 [01:09<1:54:13, 43.40it/s]
train_total_loss: -4.275020599365234
train_obj_loss: -7.010339736938477
train_reg_loss: 2.735319137573242
train_trace_term: -14.582184791564941
train_err_1: 0.05411553755402565
train_plan_mass: 1.000000238418579
  1%|          | 2609/300000 [01:10<1:54:13, 43.39it/s]
train_total_loss: -4.276175498962402
train_obj_loss: -7.011293888092041
train_reg_loss: 2.7351181507110596
train_trace_term: -14.606355667114258
train_err_1: 0.05299283564090729
train_plan_mass: 1.0000001192092896
  1%|          | 2649/300000 [01:11<1:54:10, 43.41it/s]
train_total_loss: -4.2772979736328125
train_obj_loss: -7.012234687805176
train_reg_loss: 2.734936475753784
train_trace_term: -14.630308151245117
train_err_1: 0.05189149081707001
train_plan_mass: 1.000000238418579
  1%|          | 2689/300000 [01:12<1:54:09, 43.41it/s]
train_total_loss: -4.2783918380737305
train_obj_loss: -7.013157844543457
train_reg_loss: 2.7347662448883057
train_trace_term: -14.654059410095215
train_err_1: 0.050820138305425644
train_plan_mass: 1.0
  1%|          | 2729/300000 [01:13<1:54:10, 43.39it/s]
train_total_loss: -4.279457092285156
train_obj_loss: -7.014062881469727
train_reg_loss: 2.734605550765991
train_trace_term: -14.677589416503906
train_err_1: 0.04977388307452202
train_plan_mass: 1.000000238418579
  1%|          | 2769/300000 [01:14<1:54:06, 43.41it/s]
train_total_loss: -4.280496597290039
train_obj_loss: -7.014946937561035
train_reg_loss: 2.734450578689575
train_trace_term: -14.700873374938965
train_err_1: 0.048758365213871
train_plan_mass: 1.0000001192092896
  1%|          | 2809/300000 [01:15<1:54:03, 43.43it/s]
train_total_loss: -4.281508445739746
train_obj_loss: -7.0158162117004395
train_reg_loss: 2.7343080043792725
train_trace_term: -14.723953247070312
train_err_1: 0.04776805266737938
train_plan_mass: 1.000000238418579
  1%|          | 2849/300000 [01:16<1:54:05, 43.41it/s]
train_total_loss: -4.282496452331543
train_obj_loss: -7.016671657562256
train_reg_loss: 2.734175443649292
train_trace_term: -14.746818542480469
train_err_1: 0.046803027391433716
train_plan_mass: 1.0000001192092896
  1%|          | 2889/300000 [01:17<1:54:05, 43.40it/s]
train_total_loss: -4.283458709716797
train_obj_loss: -7.017510414123535
train_reg_loss: 2.734051465988159
train_trace_term: -14.769463539123535
train_err_1: 0.045861463993787766
train_plan_mass: 1.0000001192092896
  1%|          | 2929/300000 [01:18<1:55:08, 43.00it/s]
train_total_loss: -4.284396648406982
train_obj_loss: -7.018333435058594
train_reg_loss: 2.7339367866516113
train_trace_term: -14.791864395141602
train_err_1: 0.04494393616914749
train_plan_mass: 1.0000001192092896
  1%|          | 2969/300000 [01:18<1:54:16, 43.32it/s]
train_total_loss: -4.285312652587891
train_obj_loss: -7.019144058227539
train_reg_loss: 2.7338311672210693
train_trace_term: -14.814000129699707
train_err_1: 0.04404660686850548
train_plan_mass: 1.000000238418579
  1%|          | 3009/300000 [01:19<1:54:05, 43.38it/s]
train_total_loss: -4.286206245422363
train_obj_loss: -7.019941806793213
train_reg_loss: 2.7337357997894287
train_trace_term: -14.835853576660156
train_err_1: 0.04317041486501694
train_plan_mass: 1.0000001192092896
  1%|          | 3049/300000 [01:20<1:53:57, 43.43it/s]
train_total_loss: -4.287075519561768
train_obj_loss: -7.020715236663818
train_reg_loss: 2.733639717102051
train_trace_term: -14.857487678527832
train_err_1: 0.0423128716647625
train_plan_mass: 1.000000238418579
  1%|          | 3089/300000 [01:21<1:53:55, 43.43it/s]
train_total_loss: -4.287924289703369
train_obj_loss: -7.021472930908203
train_reg_loss: 2.733548641204834
train_trace_term: -14.878820419311523
train_err_1: 0.041477080434560776
train_plan_mass: 1.0000001192092896
  1%|          | 3129/300000 [01:22<1:54:01, 43.39it/s]
train_total_loss: -4.288752555847168
train_obj_loss: -7.022213459014893
train_reg_loss: 2.7334611415863037
train_trace_term: -14.899877548217773
train_err_1: 0.0406612753868103
train_plan_mass: 1.0000001192092896
  1%|          | 3169/300000 [01:23<1:54:00, 43.39it/s]
train_total_loss: -4.2895612716674805
train_obj_loss: -7.022946357727051
train_reg_loss: 2.733384847640991
train_trace_term: -14.920721054077148
train_err_1: 0.03986552357673645
train_plan_mass: 1.000000238418579
  1%|          | 3209/300000 [01:24<1:54:00, 43.39it/s]
train_total_loss: -4.290350914001465
train_obj_loss: -7.023663520812988
train_reg_loss: 2.7333128452301025
train_trace_term: -14.941354751586914
train_err_1: 0.03908606618642807
train_plan_mass: 1.0000001192092896
  1%|          | 3249/300000 [01:25<1:53:54, 43.42it/s]
train_total_loss: -4.2911224365234375
train_obj_loss: -7.024359703063965
train_reg_loss: 2.7332372665405273
train_trace_term: -14.961776733398438
train_err_1: 0.03832053393125534
train_plan_mass: 1.0000001192092896
  1%|          | 3289/300000 [01:26<1:53:52, 43.43it/s]
train_total_loss: -4.291875839233398
train_obj_loss: -7.025043487548828
train_reg_loss: 2.7331676483154297
train_trace_term: -14.982016563415527
train_err_1: 0.037571556866168976
train_plan_mass: 1.0000001192092896
  1%|          | 3329/300000 [01:27<1:53:53, 43.41it/s]
train_total_loss: -4.2926130294799805
train_obj_loss: -7.025719165802002
train_reg_loss: 2.7331061363220215
train_trace_term: -15.002043724060059
train_err_1: 0.03684087470173836
train_plan_mass: 1.0000001192092896
  1%|          | 3369/300000 [01:28<1:55:02, 42.97it/s]
train_total_loss: -4.293334007263184
train_obj_loss: -7.0263872146606445
train_reg_loss: 2.73305344581604
train_trace_term: -15.021830558776855
train_err_1: 0.03612760454416275
train_plan_mass: 1.000000238418579
  1%|          | 3409/300000 [01:29<1:53:54, 43.39it/s]
train_total_loss: -4.294039249420166
train_obj_loss: -7.027040481567383
train_reg_loss: 2.733001232147217
train_trace_term: -15.041414260864258
train_err_1: 0.03542765974998474
train_plan_mass: 1.0000001192092896
  1%|          | 3449/300000 [01:29<1:53:51, 43.41it/s]
train_total_loss: -4.294726371765137
train_obj_loss: -7.027679443359375
train_reg_loss: 2.732952833175659
train_trace_term: -15.060836791992188
train_err_1: 0.034739986062049866
train_plan_mass: 1.0000001192092896
  1%|          | 3489/300000 [01:30<1:53:47, 43.43it/s]
train_total_loss: -4.295400619506836
train_obj_loss: -7.028306007385254
train_reg_loss: 2.732905626296997
train_trace_term: -15.080081939697266
train_err_1: 0.03406800702214241
train_plan_mass: 1.0000001192092896
  1%|          | 3529/300000 [01:31<1:53:48, 43.42it/s]
train_total_loss: -4.296059608459473
train_obj_loss: -7.028921127319336
train_reg_loss: 2.7328615188598633
train_trace_term: -15.099115371704102
train_err_1: 0.03341430425643921
train_plan_mass: 1.0000001192092896
  1%|          | 3569/300000 [01:32<1:53:51, 43.39it/s]
train_total_loss: -4.296706676483154
train_obj_loss: -7.0295305252075195
train_reg_loss: 2.7328238487243652
train_trace_term: -15.11791706085205
train_err_1: 0.03277728334069252
train_plan_mass: 1.0000001192092896
  1%|          | 3609/300000 [01:33<1:53:47, 43.41it/s]
train_total_loss: -4.29733943939209
train_obj_loss: -7.030135154724121
train_reg_loss: 2.7327959537506104
train_trace_term: -15.136527061462402
train_err_1: 0.032150931656360626
train_plan_mass: 1.0000001192092896
  1%|          | 3649/300000 [01:34<1:57:24, 42.07it/s]
train_total_loss: -4.29796028137207
train_obj_loss: -7.030732154846191
train_reg_loss: 2.7327721118927
train_trace_term: -15.154977798461914
train_err_1: 0.031532254070043564
train_plan_mass: 1.000000238418579
  1%|          | 3689/300000 [01:35<1:53:56, 43.35it/s]
train_total_loss: -4.298567771911621
train_obj_loss: -7.031323432922363
train_reg_loss: 2.7327558994293213
train_trace_term: -15.173255920410156
train_err_1: 0.03093232586979866
train_plan_mass: 1.0000001192092896
  1%|          | 3729/300000 [01:36<1:57:19, 42.09it/s]
train_total_loss: -4.299162864685059
train_obj_loss: -7.031904220581055
train_reg_loss: 2.732741117477417
train_trace_term: -15.19138240814209
train_err_1: 0.030352063477039337
train_plan_mass: 1.000000238418579
  1%|▏         | 3769/300000 [01:37<1:53:59, 43.31it/s]
train_total_loss: -4.299745559692383
train_obj_loss: -7.032473564147949
train_reg_loss: 2.7327280044555664
train_trace_term: -15.209346771240234
train_err_1: 0.02977524697780609
train_plan_mass: 1.000000238418579
  1%|▏         | 3809/300000 [01:38<1:53:44, 43.40it/s]
train_total_loss: -4.30031681060791
train_obj_loss: -7.033037185668945
train_reg_loss: 2.732720375061035
train_trace_term: -15.22720718383789
train_err_1: 0.029205389320850372
train_plan_mass: 1.0000001192092896
  1%|▏         | 3849/300000 [01:39<1:53:53, 43.34it/s]
train_total_loss: -4.300875663757324
train_obj_loss: -7.033592224121094
train_reg_loss: 2.7327165603637695
train_trace_term: -15.244948387145996
train_err_1: 0.028645774349570274
train_plan_mass: 1.000000238418579
  1%|▏         | 3889/300000 [01:40<1:53:37, 43.43it/s]
train_total_loss: -4.301421165466309
train_obj_loss: -7.034137725830078
train_reg_loss: 2.7327163219451904
train_trace_term: -15.262567520141602
train_err_1: 0.02809854969382286
train_plan_mass: 1.000000238418579
  1%|▏         | 3929/300000 [01:40<1:53:34, 43.45it/s]
train_total_loss: -4.301955699920654
train_obj_loss: -7.034676551818848
train_reg_loss: 2.7327208518981934
train_trace_term: -15.280077934265137
train_err_1: 0.027559522539377213
train_plan_mass: 1.000000238418579
  1%|▏         | 3969/300000 [01:41<1:53:28, 43.48it/s]
train_total_loss: -4.3024797439575195
train_obj_loss: -7.035207748413086
train_reg_loss: 2.7327280044555664
train_trace_term: -15.297479629516602
train_err_1: 0.027030151337385178
train_plan_mass: 1.000000238418579
  1%|▏         | 4009/300000 [01:42<1:53:36, 43.43it/s]
train_total_loss: -4.302992820739746
train_obj_loss: -7.035731315612793
train_reg_loss: 2.732738494873047
train_trace_term: -15.314762115478516
train_err_1: 0.02651170641183853
train_plan_mass: 1.000000238418579
  1%|▏         | 4049/300000 [01:43<1:53:37, 43.41it/s]
train_total_loss: -4.303495407104492
train_obj_loss: -7.0362443923950195
train_reg_loss: 2.7327487468719482
train_trace_term: -15.33192253112793
train_err_1: 0.026003990322351456
train_plan_mass: 1.0000001192092896
  1%|▏         | 4089/300000 [01:44<1:56:24, 42.37it/s]
train_total_loss: -4.303987503051758
train_obj_loss: -7.036746025085449
train_reg_loss: 2.7327585220336914
train_trace_term: -15.348968505859375
train_err_1: 0.02550678700208664
train_plan_mass: 1.000000238418579
  1%|▏         | 4129/300000 [01:45<1:53:54, 43.29it/s]
train_total_loss: -4.304470539093018
train_obj_loss: -7.037238121032715
train_reg_loss: 2.7327675819396973
train_trace_term: -15.365896224975586
train_err_1: 0.025024548172950745
train_plan_mass: 1.0000001192092896
  1%|▏         | 4169/300000 [01:46<1:53:39, 43.38it/s]
train_total_loss: -4.3049421310424805
train_obj_loss: -7.037726402282715
train_reg_loss: 2.7327840328216553
train_trace_term: -15.382673263549805
train_err_1: 0.024548351764678955
train_plan_mass: 1.000000238418579
  1%|▏         | 4209/300000 [01:47<1:53:33, 43.41it/s]
train_total_loss: -4.305407524108887
train_obj_loss: -7.038216590881348
train_reg_loss: 2.73280930519104
train_trace_term: -15.399335861206055
train_err_1: 0.02408219315111637
train_plan_mass: 1.000000238418579
  1%|▏         | 4249/300000 [01:48<1:53:32, 43.41it/s]
train_total_loss: -4.305861473083496
train_obj_loss: -7.038689613342285
train_reg_loss: 2.73282790184021
train_trace_term: -15.415943145751953
train_err_1: 0.02362413890659809
train_plan_mass: 1.000000238418579
  1%|▏         | 4289/300000 [01:49<1:53:29, 43.43it/s]
train_total_loss: -4.306306838989258
train_obj_loss: -7.039151191711426
train_reg_loss: 2.732844114303589
train_trace_term: -15.432448387145996
train_err_1: 0.023177605122327805
train_plan_mass: 1.0000001192092896
  1%|▏         | 4329/300000 [01:50<1:53:33, 43.39it/s]
train_total_loss: -4.3067402839660645
train_obj_loss: -7.039609909057617
train_reg_loss: 2.7328696250915527
train_trace_term: -15.448862075805664
train_err_1: 0.022733056917786598
train_plan_mass: 1.0000001192092896
  1%|▏         | 4369/300000 [01:51<1:53:30, 43.41it/s]
train_total_loss: -4.307168960571289
train_obj_loss: -7.040067672729492
train_reg_loss: 2.732898712158203
train_trace_term: -15.465155601501465
train_err_1: 0.022302918136119843
train_plan_mass: 1.000000238418579
  1%|▏         | 4409/300000 [01:51<1:53:29, 43.41it/s]
train_total_loss: -4.30758810043335
train_obj_loss: -7.040523529052734
train_reg_loss: 2.7329354286193848
train_trace_term: -15.481305122375488
train_err_1: 0.02188834175467491
train_plan_mass: 1.000000238418579
  1%|▏         | 4449/300000 [01:52<1:53:33, 43.38it/s]
train_total_loss: -4.308000564575195
train_obj_loss: -7.040964126586914
train_reg_loss: 2.732963800430298
train_trace_term: -15.497383117675781
train_err_1: 0.021476535126566887
train_plan_mass: 1.000000238418579
  1%|▏         | 4489/300000 [01:53<1:53:33, 43.37it/s]
train_total_loss: -4.308403015136719
train_obj_loss: -7.041403770446777
train_reg_loss: 2.7330007553100586
train_trace_term: -15.513338088989258
train_err_1: 0.021075036376714706
train_plan_mass: 1.000000238418579
  2%|▏         | 4529/300000 [01:54<1:53:59, 43.20it/s]
train_total_loss: -4.308798789978027
train_obj_loss: -7.041843414306641
train_reg_loss: 2.733044385910034
train_trace_term: -15.529207229614258
train_err_1: 0.02068094164133072
train_plan_mass: 1.000000238418579
  2%|▏         | 4569/300000 [01:55<1:53:30, 43.38it/s]
train_total_loss: -4.309191703796387
train_obj_loss: -7.04228401184082
train_reg_loss: 2.7330925464630127
train_trace_term: -15.54505729675293
train_err_1: 0.020291995257139206
train_plan_mass: 1.000000238418579
  2%|▏         | 4609/300000 [01:56<1:53:15, 43.47it/s]
train_total_loss: -4.309576034545898
train_obj_loss: -7.042729377746582
train_reg_loss: 2.7331531047821045
train_trace_term: -15.560786247253418
train_err_1: 0.019922513514757156
train_plan_mass: 1.0000001192092896
  2%|▏         | 4649/300000 [01:57<1:53:15, 43.46it/s]
train_total_loss: -4.30994987487793
train_obj_loss: -7.043148040771484
train_reg_loss: 2.7331981658935547
train_trace_term: -15.57640552520752
train_err_1: 0.019550088793039322
train_plan_mass: 1.0000001192092896
  2%|▏         | 4689/300000 [01:58<1:53:16, 43.45it/s]
train_total_loss: -4.3103179931640625
train_obj_loss: -7.0435590744018555
train_reg_loss: 2.733241319656372
train_trace_term: -15.591983795166016
train_err_1: 0.019182337448000908
train_plan_mass: 1.000000238418579
  2%|▏         | 4729/300000 [01:59<1:53:12, 43.47it/s]
train_total_loss: -4.310676574707031
train_obj_loss: -7.043957710266113
train_reg_loss: 2.733281135559082
train_trace_term: -15.60751724243164
train_err_1: 0.01882229745388031
train_plan_mass: 1.0000001192092896
  2%|▏         | 4769/300000 [02:00<1:53:10, 43.48it/s]
train_total_loss: -4.311029434204102
train_obj_loss: -7.044349670410156
train_reg_loss: 2.7333202362060547
train_trace_term: -15.622964859008789
train_err_1: 0.01847468502819538
train_plan_mass: 1.000000238418579
  2%|▏         | 4809/300000 [02:01<1:53:11, 43.46it/s]
train_total_loss: -4.311375141143799
train_obj_loss: -7.044730186462402
train_reg_loss: 2.7333550453186035
train_trace_term: -15.638422012329102
train_err_1: 0.01813659444451332
train_plan_mass: 1.000000238418579
  2%|▏         | 4849/300000 [02:02<1:53:15, 43.43it/s]
train_total_loss: -4.311713218688965
train_obj_loss: -7.045109748840332
train_reg_loss: 2.733396530151367
train_trace_term: -15.653730392456055
train_err_1: 0.017794180661439896
train_plan_mass: 1.0000001192092896
  2%|▏         | 4889/300000 [02:02<1:53:56, 43.17it/s]
train_total_loss: -4.312045097351074
train_obj_loss: -7.045489311218262
train_reg_loss: 2.7334442138671875
train_trace_term: -15.668968200683594
train_err_1: 0.017455559223890305
train_plan_mass: 1.000000238418579
  2%|▏         | 4929/300000 [02:03<1:53:35, 43.30it/s]
train_total_loss: -4.3123674392700195
train_obj_loss: -7.045865058898926
train_reg_loss: 2.7334978580474854
train_trace_term: -15.68416976928711
train_err_1: 0.017120234668254852
train_plan_mass: 1.0000001192092896
  2%|▏         | 4969/300000 [02:04<1:55:20, 42.63it/s]
train_total_loss: -4.312684059143066
train_obj_loss: -7.046238899230957
train_reg_loss: 2.7335546016693115
train_trace_term: -15.699287414550781
train_err_1: 0.01678948476910591
train_plan_mass: 1.0000001192092896
  2%|▏         | 5009/300000 [02:05<1:53:21, 43.37it/s]
train_total_loss: -4.312994003295898
train_obj_loss: -7.046606063842773
train_reg_loss: 2.733612060546875
train_trace_term: -15.714323043823242
train_err_1: 0.01646493747830391
train_plan_mass: 1.000000238418579
  2%|▏         | 5049/300000 [02:06<1:53:30, 43.31it/s]
train_total_loss: -4.313298225402832
train_obj_loss: -7.046964645385742
train_reg_loss: 2.7336666584014893
train_trace_term: -15.729289054870605
train_err_1: 0.016145242378115654
train_plan_mass: 1.0000001192092896
  2%|▏         | 5089/300000 [02:07<1:53:22, 43.36it/s]
train_total_loss: -4.313595294952393
train_obj_loss: -7.047316551208496
train_reg_loss: 2.7337212562561035
train_trace_term: -15.744183540344238
train_err_1: 0.015831144526600838
train_plan_mass: 1.000000238418579
  2%|▏         | 5129/300000 [02:08<1:53:29, 43.30it/s]
train_total_loss: -4.313888072967529
train_obj_loss: -7.047663688659668
train_reg_loss: 2.7337756156921387
train_trace_term: -15.758992195129395
train_err_1: 0.0155264213681221
train_plan_mass: 1.000000238418579
  2%|▏         | 5169/300000 [02:09<1:53:24, 43.33it/s]
train_total_loss: -4.314173698425293
train_obj_loss: -7.048006057739258
train_reg_loss: 2.733832359313965
train_trace_term: -15.773731231689453
train_err_1: 0.015227009542286396
train_plan_mass: 1.0000001192092896
  2%|▏         | 5209/300000 [02:10<1:53:23, 43.33it/s]
train_total_loss: -4.314454078674316
train_obj_loss: -7.04835319519043
train_reg_loss: 2.7338991165161133
train_trace_term: -15.7884521484375
train_err_1: 0.01493141707032919
train_plan_mass: 1.000000238418579
  2%|▏         | 5249/300000 [02:11<1:53:31, 43.27it/s]
train_total_loss: -4.314727306365967
train_obj_loss: -7.048701286315918
train_reg_loss: 2.733973979949951
train_trace_term: -15.80311393737793
train_err_1: 0.014643745496869087
train_plan_mass: 1.000000238418579
  2%|▏         | 5289/300000 [02:12<1:53:27, 43.29it/s]
train_total_loss: -4.314995765686035
train_obj_loss: -7.049040794372559
train_reg_loss: 2.7340447902679443
train_trace_term: -15.817621231079102
train_err_1: 0.014371290802955627
train_plan_mass: 1.000000238418579
  2%|▏         | 5329/300000 [02:13<1:53:24, 43.31it/s]
train_total_loss: -4.3152570724487305
train_obj_loss: -7.049375534057617
train_reg_loss: 2.7341182231903076
train_trace_term: -15.832063674926758
train_err_1: 0.014090470038354397
train_plan_mass: 1.0000001192092896
  2%|▏         | 5369/300000 [02:13<1:53:06, 43.41it/s]
train_total_loss: -4.315515518188477
train_obj_loss: -7.049713134765625
train_reg_loss: 2.7341973781585693
train_trace_term: -15.846356391906738
train_err_1: 0.013840332627296448
train_plan_mass: 1.000000238418579
  2%|▏         | 5409/300000 [02:14<1:52:54, 43.49it/s]
train_total_loss: -4.315770149230957
train_obj_loss: -7.050029754638672
train_reg_loss: 2.734259605407715
train_trace_term: -15.860763549804688
train_err_1: 0.0135550731793046
train_plan_mass: 1.000000238418579
  2%|▏         | 5449/300000 [02:15<1:54:11, 42.99it/s]
train_total_loss: -4.31601619720459
train_obj_loss: -7.050350189208984
train_reg_loss: 2.7343337535858154
train_trace_term: -15.875154495239258
train_err_1: 0.013284646905958652
train_plan_mass: 1.000000238418579
  2%|▏         | 5489/300000 [02:16<1:52:56, 43.46it/s]
train_total_loss: -4.316257953643799
train_obj_loss: -7.050670623779297
train_reg_loss: 2.734412670135498
train_trace_term: -15.889436721801758
train_err_1: 0.013023250736296177
train_plan_mass: 1.0000001192092896
  2%|▏         | 5529/300000 [02:17<1:52:53, 43.47it/s]
train_total_loss: -4.316495418548584
train_obj_loss: -7.050994873046875
train_reg_loss: 2.734499454498291
train_trace_term: -15.90361213684082
train_err_1: 0.01277176570147276
train_plan_mass: 1.000000238418579
  2%|▏         | 5569/300000 [02:18<1:52:56, 43.45it/s]
train_total_loss: -4.3167266845703125
train_obj_loss: -7.051313400268555
train_reg_loss: 2.734586715698242
train_trace_term: -15.917701721191406
train_err_1: 0.012520931661128998
train_plan_mass: 1.0000001192092896
  2%|▏         | 5609/300000 [02:19<1:52:52, 43.47it/s]
train_total_loss: -4.31695556640625
train_obj_loss: -7.051609992980957
train_reg_loss: 2.734654188156128
train_trace_term: -15.931675910949707
train_err_1: 0.012306362390518188
train_plan_mass: 1.0000001192092896
  2%|▏         | 5649/300000 [02:20<1:52:49, 43.48it/s]
train_total_loss: -4.317177772521973
train_obj_loss: -7.05191707611084
train_reg_loss: 2.7347395420074463
train_trace_term: -15.94561767578125
train_err_1: 0.012040314264595509
train_plan_mass: 1.000000238418579
  2%|▏         | 5689/300000 [02:21<1:52:50, 43.47it/s]
train_total_loss: -4.317395210266113
train_obj_loss: -7.052217483520508
train_reg_loss: 2.7348222732543945
train_trace_term: -15.959466934204102
train_err_1: 0.011790631338953972
train_plan_mass: 1.0000001192092896
  2%|▏         | 5729/300000 [02:22<1:53:07, 43.36it/s]
train_total_loss: -4.317605972290039
train_obj_loss: -7.052521705627441
train_reg_loss: 2.7349157333374023
train_trace_term: -15.97313117980957
train_err_1: 0.011552557349205017
train_plan_mass: 1.0000001192092896
  2%|▏         | 5769/300000 [02:23<1:52:49, 43.46it/s]
train_total_loss: -4.317812919616699
train_obj_loss: -7.052806854248047
train_reg_loss: 2.7349941730499268
train_trace_term: -15.986793518066406
train_err_1: 0.011316130869090557
train_plan_mass: 1.0000001192092896
  2%|▏         | 5809/300000 [02:24<1:52:48, 43.47it/s]
train_total_loss: -4.318014144897461
train_obj_loss: -7.053088188171387
train_reg_loss: 2.7350738048553467
train_trace_term: -16.000370025634766
train_err_1: 0.011084762401878834
train_plan_mass: 1.000000238418579
  2%|▏         | 5849/300000 [02:24<1:52:44, 43.48it/s]
train_total_loss: -4.318210601806641
train_obj_loss: -7.053363800048828
train_reg_loss: 2.7351531982421875
train_trace_term: -16.01384925842285
train_err_1: 0.010861633345484734
train_plan_mass: 1.0000001192092896
  2%|▏         | 5889/300000 [02:25<1:52:42, 43.49it/s]
train_total_loss: -4.31840705871582
train_obj_loss: -7.0536394119262695
train_reg_loss: 2.7352325916290283
train_trace_term: -16.02726173400879
train_err_1: 0.010642118752002716
train_plan_mass: 1.0000001192092896
  2%|▏         | 5929/300000 [02:26<1:52:41, 43.49it/s]
train_total_loss: -4.318596839904785
train_obj_loss: -7.053913116455078
train_reg_loss: 2.735316038131714
train_trace_term: -16.040531158447266
train_err_1: 0.01043820008635521
train_plan_mass: 1.0000001192092896
  2%|▏         | 5969/300000 [02:27<1:52:43, 43.47it/s]
train_total_loss: -4.31878137588501
train_obj_loss: -7.054189682006836
train_reg_loss: 2.735408306121826
train_trace_term: -16.0537109375
train_err_1: 0.010234691202640533
train_plan_mass: 1.000000238418579
  2%|▏         | 6009/300000 [02:28<1:52:38, 43.50it/s]
train_total_loss: -4.318965911865234
train_obj_loss: -7.054474830627441
train_reg_loss: 2.735508680343628
train_trace_term: -16.066417694091797
train_err_1: 0.010046628303825855
train_plan_mass: 1.000000238418579
  2%|▏         | 6049/300000 [02:29<1:52:41, 43.47it/s]
train_total_loss: -4.3191447257995605
train_obj_loss: -7.054751396179199
train_reg_loss: 2.7356066703796387
train_trace_term: -16.079425811767578
train_err_1: 0.009822733700275421
train_plan_mass: 1.000000238418579
  2%|▏         | 6089/300000 [02:30<1:52:38, 43.49it/s]
train_total_loss: -4.319318771362305
train_obj_loss: -7.055015563964844
train_reg_loss: 2.73569655418396
train_trace_term: -16.09246826171875
train_err_1: 0.00961484108120203
train_plan_mass: 1.000000238418579
  2%|▏         | 6129/300000 [02:31<1:52:40, 43.47it/s]
train_total_loss: -4.319490432739258
train_obj_loss: -7.05527400970459
train_reg_loss: 2.735783576965332
train_trace_term: -16.1055965423584
train_err_1: 0.009415652602910995
train_plan_mass: 1.0000001192092896
  2%|▏         | 6169/300000 [02:32<1:52:57, 43.36it/s]
train_total_loss: -4.319655418395996
train_obj_loss: -7.055530548095703
train_reg_loss: 2.735875129699707
train_trace_term: -16.118663787841797
train_err_1: 0.009214887395501137
train_plan_mass: 1.000000238418579
  2%|▏         | 6209/300000 [02:33<1:52:41, 43.45it/s]
train_total_loss: -4.319816589355469
train_obj_loss: -7.0557756423950195
train_reg_loss: 2.735959053039551
train_trace_term: -16.131576538085938
train_err_1: 0.009014811366796494
train_plan_mass: 1.000000238418579
  2%|▏         | 6249/300000 [02:34<1:52:37, 43.47it/s]
train_total_loss: -4.319972991943359
train_obj_loss: -7.056021690368652
train_reg_loss: 2.736048936843872
train_trace_term: -16.14439582824707
train_err_1: 0.008817752823233604
train_plan_mass: 1.0000001192092896
  2%|▏         | 6289/300000 [02:35<1:52:34, 43.48it/s]
train_total_loss: -4.320125579833984
train_obj_loss: -7.056264877319336
train_reg_loss: 2.7361395359039307
train_trace_term: -16.15712547302246
train_err_1: 0.008626826107501984
train_plan_mass: 1.000000238418579
  2%|▏         | 6329/300000 [02:35<1:52:30, 43.50it/s]
train_total_loss: -4.320274353027344
train_obj_loss: -7.056509971618652
train_reg_loss: 2.7362356185913086
train_trace_term: -16.16977310180664
train_err_1: 0.008437538519501686
train_plan_mass: 1.0000001192092896
  2%|▏         | 6369/300000 [02:36<1:52:30, 43.50it/s]
train_total_loss: -4.320422172546387
train_obj_loss: -7.056756973266602
train_reg_loss: 2.736334800720215
train_trace_term: -16.182281494140625
train_err_1: 0.008268874138593674
train_plan_mass: 1.000000238418579
  2%|▏         | 6409/300000 [02:37<1:52:29, 43.50it/s]
train_total_loss: -4.320562362670898
train_obj_loss: -7.0569915771484375
train_reg_loss: 2.736429452896118
train_trace_term: -16.19474220275879
train_err_1: 0.008079702034592628
train_plan_mass: 1.0000001192092896
  2%|▏         | 6449/300000 [02:38<1:52:28, 43.50it/s]
train_total_loss: -4.320703029632568
train_obj_loss: -7.057225227355957
train_reg_loss: 2.7365221977233887
train_trace_term: -16.207128524780273
train_err_1: 0.007905106991529465
train_plan_mass: 1.0000001192092896
  2%|▏         | 6489/300000 [02:39<1:52:29, 43.48it/s]
train_total_loss: -4.320837497711182
train_obj_loss: -7.05745792388916
train_reg_loss: 2.7366204261779785
train_trace_term: -16.21941566467285
train_err_1: 0.007740546017885208
train_plan_mass: 1.0000001192092896
  2%|▏         | 6529/300000 [02:40<1:52:27, 43.50it/s]
train_total_loss: -4.320970058441162
train_obj_loss: -7.057686805725098
train_reg_loss: 2.7367167472839355
train_trace_term: -16.231645584106445
train_err_1: 0.007567186374217272
train_plan_mass: 1.000000238418579
  2%|▏         | 6569/300000 [02:41<1:52:29, 43.48it/s]
train_total_loss: -4.321099281311035
train_obj_loss: -7.057918548583984
train_reg_loss: 2.7368195056915283
train_trace_term: -16.24374771118164
train_err_1: 0.007399194873869419
train_plan_mass: 1.000000238418579
  2%|▏         | 6609/300000 [02:42<1:52:41, 43.39it/s]
train_total_loss: -4.321221828460693
train_obj_loss: -7.058132171630859
train_reg_loss: 2.736910343170166
train_trace_term: -16.255783081054688
train_err_1: 0.007229584269225597
train_plan_mass: 1.000000238418579
  2%|▏         | 6649/300000 [02:43<1:52:31, 43.45it/s]
train_total_loss: -4.321343898773193
train_obj_loss: -7.058342933654785
train_reg_loss: 2.736999034881592
train_trace_term: -16.267759323120117
train_err_1: 0.007067727856338024
train_plan_mass: 1.000000238418579
  2%|▏         | 6689/300000 [02:44<1:52:27, 43.47it/s]
train_total_loss: -4.321463584899902
train_obj_loss: -7.058551788330078
train_reg_loss: 2.7370879650115967
train_trace_term: -16.279621124267578
train_err_1: 0.006911196280270815
train_plan_mass: 1.000000238418579
  2%|▏         | 6729/300000 [02:45<1:52:21, 43.50it/s]
train_total_loss: -4.321579933166504
train_obj_loss: -7.058751106262207
train_reg_loss: 2.737170934677124
train_trace_term: -16.291364669799805
train_err_1: 0.006767609156668186
train_plan_mass: 1.000000238418579
  2%|▏         | 6769/300000 [02:45<1:52:21, 43.49it/s]
train_total_loss: -4.32169246673584
train_obj_loss: -7.058950424194336
train_reg_loss: 2.737258195877075
train_trace_term: -16.303010940551758
train_err_1: 0.006624483969062567
train_plan_mass: 1.0000001192092896
  2%|▏         | 6809/300000 [02:46<1:52:29, 43.44it/s]
train_total_loss: -4.321804046630859
train_obj_loss: -7.059152603149414
train_reg_loss: 2.7373485565185547
train_trace_term: -16.314517974853516
train_err_1: 0.006462317891418934
train_plan_mass: 1.000000238418579
  2%|▏         | 6849/300000 [02:47<1:52:19, 43.50it/s]
train_total_loss: -4.321910858154297
train_obj_loss: -7.059355735778809
train_reg_loss: 2.737445116043091
train_trace_term: -16.326082229614258
train_err_1: 0.006320203188806772
train_plan_mass: 1.000000238418579
  2%|▏         | 6889/300000 [02:48<1:52:21, 43.48it/s]
train_total_loss: -4.322016716003418
train_obj_loss: -7.059544563293457
train_reg_loss: 2.737528085708618
train_trace_term: -16.337482452392578
train_err_1: 0.006171904504299164
train_plan_mass: 1.0000001192092896
  2%|▏         | 6929/300000 [02:49<1:52:16, 43.50it/s]
train_total_loss: -4.322117805480957
train_obj_loss: -7.059727668762207
train_reg_loss: 2.73760986328125
train_trace_term: -16.348827362060547
train_err_1: 0.006031166762113571
train_plan_mass: 1.0000001192092896
  2%|▏         | 6969/300000 [02:50<1:52:19, 43.48it/s]
train_total_loss: -4.322215557098389
train_obj_loss: -7.059906005859375
train_reg_loss: 2.7376904487609863
train_trace_term: -16.359909057617188
train_err_1: 0.005879120901226997
train_plan_mass: 1.000000238418579
  2%|▏         | 7009/300000 [02:51<1:52:20, 43.47it/s]
train_total_loss: -4.322311878204346
train_obj_loss: -7.060083389282227
train_reg_loss: 2.737771511077881
train_trace_term: -16.370975494384766
train_err_1: 0.005744688678532839
train_plan_mass: 1.0000001192092896
  2%|▏         | 7049/300000 [02:52<1:52:46, 43.30it/s]
train_total_loss: -4.32240629196167
train_obj_loss: -7.0602569580078125
train_reg_loss: 2.7378506660461426
train_trace_term: -16.38192367553711
train_err_1: 0.005615598987787962
train_plan_mass: 1.000000238418579
  2%|▏         | 7089/300000 [02:53<1:52:16, 43.48it/s]
train_total_loss: -4.3224968910217285
train_obj_loss: -7.060431480407715
train_reg_loss: 2.7379345893859863
train_trace_term: -16.39278221130371
train_err_1: 0.005478468723595142
train_plan_mass: 1.0000001192092896
  2%|▏         | 7129/300000 [02:54<1:52:13, 43.49it/s]
train_total_loss: -4.3225836753845215
train_obj_loss: -7.06060791015625
train_reg_loss: 2.7380242347717285
train_trace_term: -16.40345001220703
train_err_1: 0.005342697724699974
train_plan_mass: 1.0000001192092896
  2%|▏         | 7169/300000 [02:55<1:52:14, 43.48it/s]
train_total_loss: -4.3226704597473145
train_obj_loss: -7.060770034790039
train_reg_loss: 2.7380995750427246
train_trace_term: -16.41346549987793
train_err_1: 0.005197461694478989
train_plan_mass: 1.0000001192092896
  2%|▏         | 7209/300000 [02:56<1:52:15, 43.47it/s]
train_total_loss: -4.322756767272949
train_obj_loss: -7.060966491699219
train_reg_loss: 2.7382094860076904
train_trace_term: -16.423725128173828
train_err_1: 0.005118133500218391
train_plan_mass: 1.0000001192092896
  2%|▏         | 7249/300000 [02:56<1:52:09, 43.50it/s]
train_total_loss: -4.322837829589844
train_obj_loss: -7.061138153076172
train_reg_loss: 2.738300085067749
train_trace_term: -16.434099197387695
train_err_1: 0.004984197206795216
train_plan_mass: 1.0000001192092896
  2%|▏         | 7289/300000 [02:57<1:52:11, 43.48it/s]
train_total_loss: -4.322917938232422
train_obj_loss: -7.061309814453125
train_reg_loss: 2.738391637802124
train_trace_term: -16.444217681884766
train_err_1: 0.004866603296250105
train_plan_mass: 1.0000001192092896
  2%|▏         | 7329/300000 [02:58<1:52:13, 43.46it/s]
train_total_loss: -4.322998046875
train_obj_loss: -7.061473846435547
train_reg_loss: 2.738475799560547
train_trace_term: -16.45431137084961
train_err_1: 0.004745799582451582
train_plan_mass: 1.0000001192092896
  2%|▏         | 7369/300000 [02:59<1:52:30, 43.35it/s]
train_total_loss: -4.3230695724487305
train_obj_loss: -7.061627388000488
train_reg_loss: 2.7385575771331787
train_trace_term: -16.464141845703125
train_err_1: 0.004629617556929588
train_plan_mass: 1.0000001192092896
  2%|▏         | 7409/300000 [03:00<1:52:24, 43.38it/s]
train_total_loss: -4.323143005371094
train_obj_loss: -7.061783790588379
train_reg_loss: 2.738640785217285
train_trace_term: -16.474010467529297
train_err_1: 0.004522358067333698
train_plan_mass: 1.0000001192092896
  2%|▏         | 7449/300000 [03:01<1:52:22, 43.39it/s]
train_total_loss: -4.323217391967773
train_obj_loss: -7.061919212341309
train_reg_loss: 2.7387020587921143
train_trace_term: -16.483299255371094
train_err_1: 0.004429110791534185
train_plan_mass: 1.0000001192092896
  2%|▏         | 7489/300000 [03:02<1:52:21, 43.39it/s]
train_total_loss: -4.323284149169922
train_obj_loss: -7.062070846557617
train_reg_loss: 2.738786458969116
train_trace_term: -16.493276596069336
train_err_1: 0.004312125034630299
train_plan_mass: 1.000000238418579
  3%|▎         | 7529/300000 [03:03<1:52:05, 43.48it/s]
train_total_loss: -4.323351860046387
train_obj_loss: -7.062203407287598
train_reg_loss: 2.73885178565979
train_trace_term: -16.50251007080078
train_err_1: 0.004191157873719931
train_plan_mass: 1.000000238418579
  3%|▎         | 7569/300000 [03:04<1:52:02, 43.50it/s]
train_total_loss: -4.323415279388428
train_obj_loss: -7.062350273132324
train_reg_loss: 2.7389349937438965
train_trace_term: -16.512046813964844
train_err_1: 0.004091998562216759
train_plan_mass: 1.000000238418579
  3%|▎         | 7609/300000 [03:05<1:52:02, 43.49it/s]
train_total_loss: -4.323481559753418
train_obj_loss: -7.062495231628418
train_reg_loss: 2.739013910293579
train_trace_term: -16.52119255065918
train_err_1: 0.003990485332906246
train_plan_mass: 1.0000001192092896
  3%|▎         | 7649/300000 [03:06<1:52:06, 43.46it/s]
train_total_loss: -4.323541164398193
train_obj_loss: -7.0626325607299805
train_reg_loss: 2.739091396331787
train_trace_term: -16.530202865600586
train_err_1: 0.0038997060619294643
train_plan_mass: 1.000000238418579
  3%|▎         | 7689/300000 [03:07<1:52:02, 43.48it/s]
train_total_loss: -4.323598384857178
train_obj_loss: -7.062765121459961
train_reg_loss: 2.739166736602783
train_trace_term: -16.539081573486328
train_err_1: 0.003789302194491029
train_plan_mass: 1.0000001192092896
  3%|▎         | 7729/300000 [03:07<1:51:58, 43.50it/s]
train_total_loss: -4.323659420013428
train_obj_loss: -7.0629072189331055
train_reg_loss: 2.7392477989196777
train_trace_term: -16.54769515991211
train_err_1: 0.003752242773771286
train_plan_mass: 1.0000001192092896
  3%|▎         | 7769/300000 [03:08<1:51:55, 43.51it/s]
train_total_loss: -4.3237152099609375
train_obj_loss: -7.063054084777832
train_reg_loss: 2.7393388748168945
train_trace_term: -16.55678939819336
train_err_1: 0.0035965070128440857
train_plan_mass: 1.0000001192092896
  3%|▎         | 7809/300000 [03:09<1:51:59, 43.48it/s]
train_total_loss: -4.323772430419922
train_obj_loss: -7.0631914138793945
train_reg_loss: 2.7394192218780518
train_trace_term: -16.565143585205078
train_err_1: 0.003528670175001025
train_plan_mass: 1.0000001192092896
  3%|▎         | 7849/300000 [03:10<1:51:56, 43.50it/s]
train_total_loss: -4.323822021484375
train_obj_loss: -7.063319206237793
train_reg_loss: 2.739496946334839
train_trace_term: -16.573375701904297
train_err_1: 0.0034647921565920115
train_plan_mass: 1.0000001192092896
  3%|▎         | 7889/300000 [03:11<1:51:57, 43.48it/s]
train_total_loss: -4.32387638092041
train_obj_loss: -7.063443183898926
train_reg_loss: 2.7395665645599365
train_trace_term: -16.58148193359375
train_err_1: 0.0033959033899009228
train_plan_mass: 1.000000238418579
  3%|▎         | 7929/300000 [03:12<1:52:17, 43.35it/s]
train_total_loss: -4.32392692565918
train_obj_loss: -7.063563346862793
train_reg_loss: 2.7396364212036133
train_trace_term: -16.589771270751953
train_err_1: 0.0032528548035770655
train_plan_mass: 1.0000001192092896
  3%|▎         | 7969/300000 [03:13<1:52:01, 43.45it/s]
train_total_loss: -4.323972225189209
train_obj_loss: -7.063699722290039
train_reg_loss: 2.73972749710083
train_trace_term: -16.597900390625
train_err_1: 0.0033080000430345535
train_plan_mass: 1.000000238418579
  3%|▎         | 8009/300000 [03:14<1:51:52, 43.50it/s]
train_total_loss: -4.324021816253662
train_obj_loss: -7.063820838928223
train_reg_loss: 2.7397990226745605
train_trace_term: -16.60553741455078
train_err_1: 0.0032855719327926636
train_plan_mass: 1.000000238418579
  3%|▎         | 8049/300000 [03:15<1:51:50, 43.50it/s]
train_total_loss: -4.324069023132324
train_obj_loss: -7.063912391662598
train_reg_loss: 2.7398431301116943
train_trace_term: -16.61347770690918
train_err_1: 0.0030339108780026436
train_plan_mass: 1.0000001192092896
  3%|▎         | 8089/300000 [03:16<1:51:51, 43.50it/s]
train_total_loss: -4.324112892150879
train_obj_loss: -7.064017295837402
train_reg_loss: 2.7399041652679443
train_trace_term: -16.62115478515625
train_err_1: 0.0029526264406740665
train_plan_mass: 1.0000001192092896
  3%|▎         | 8129/300000 [03:17<1:51:51, 43.49it/s]
train_total_loss: -4.324160575866699
train_obj_loss: -7.064154624938965
train_reg_loss: 2.7399938106536865
train_trace_term: -16.62879180908203
train_err_1: 0.0031216060742735863
train_plan_mass: 1.000000238418579
  3%|▎         | 8169/300000 [03:17<1:51:47, 43.51it/s]
train_total_loss: -4.32420539855957
train_obj_loss: -7.064235687255859
train_reg_loss: 2.740030527114868
train_trace_term: -16.636173248291016
train_err_1: 0.0028273556381464005
train_plan_mass: 1.0000001192092896
  3%|▎         | 8209/300000 [03:18<1:51:45, 43.51it/s]
train_total_loss: -4.324244499206543
train_obj_loss: -7.064384460449219
train_reg_loss: 2.7401397228240967
train_trace_term: -16.643342971801758
train_err_1: 0.0032296115532517433
train_plan_mass: 1.0000001192092896
  3%|▎         | 8249/300000 [03:19<1:51:49, 43.48it/s]
train_total_loss: -4.324285984039307
train_obj_loss: -7.064435958862305
train_reg_loss: 2.740149974822998
train_trace_term: -16.650711059570312
train_err_1: 0.002684506122022867
train_plan_mass: 1.000000238418579
  3%|▎         | 8289/300000 [03:20<1:51:44, 43.51it/s]
train_total_loss: -4.32432746887207
train_obj_loss: -7.064548492431641
train_reg_loss: 2.740220785140991
train_trace_term: -16.657791137695312
train_err_1: 0.00262312195263803
train_plan_mass: 1.000000238418579
  3%|▎         | 8329/300000 [03:21<1:51:43, 43.51it/s]
train_total_loss: -4.3243632316589355
train_obj_loss: -7.064674377441406
train_reg_loss: 2.7403111457824707
train_trace_term: -16.66484832763672
train_err_1: 0.0028104986995458603
train_plan_mass: 1.000000238418579
  3%|▎         | 8369/300000 [03:22<1:53:27, 42.84it/s]
train_total_loss: -4.324400901794434
train_obj_loss: -7.064740180969238
train_reg_loss: 2.7403390407562256
train_trace_term: -16.671592712402344
train_err_1: 0.002473453525453806
train_plan_mass: 1.0000001192092896
  3%|▎         | 8409/300000 [03:23<1:52:14, 43.30it/s]
train_total_loss: -4.324438095092773
train_obj_loss: -7.064851760864258
train_reg_loss: 2.7404136657714844
train_trace_term: -16.678295135498047
train_err_1: 0.0024328844156116247
train_plan_mass: 1.000000238418579
  3%|▎         | 8449/300000 [03:24<1:51:42, 43.50it/s]
train_total_loss: -4.324470520019531
train_obj_loss: -7.06492805480957
train_reg_loss: 2.74045729637146
train_trace_term: -16.684894561767578
train_err_1: 0.002381914062425494
train_plan_mass: 1.000000238418579
  3%|▎         | 8489/300000 [03:25<1:51:40, 43.51it/s]
train_total_loss: -4.324501991271973
train_obj_loss: -7.064970970153809
train_reg_loss: 2.740468740463257
train_trace_term: -16.691728591918945
train_err_1: 0.002877916442230344
train_plan_mass: 1.000000238418579
  3%|▎         | 8529/300000 [03:26<1:51:41, 43.49it/s]
train_total_loss: -4.3245391845703125
train_obj_loss: -7.06512451171875
train_reg_loss: 2.7405855655670166
train_trace_term: -16.697778701782227
train_err_1: 0.002250146120786667
train_plan_mass: 1.000000238418579
  3%|▎         | 8569/300000 [03:27<1:51:42, 43.48it/s]
train_total_loss: -4.32457160949707
train_obj_loss: -7.065196990966797
train_reg_loss: 2.7406251430511475
train_trace_term: -16.704090118408203
train_err_1: 0.002173601184040308
train_plan_mass: 1.0000001192092896
  3%|▎         | 8609/300000 [03:28<1:51:36, 43.51it/s]
train_total_loss: -4.324600696563721
train_obj_loss: -7.065315246582031
train_reg_loss: 2.7407145500183105
train_trace_term: -16.71026611328125
train_err_1: 0.0024149552918970585
train_plan_mass: 1.0000001192092896
  3%|▎         | 8649/300000 [03:28<1:51:36, 43.51it/s]
train_total_loss: -4.324631214141846
train_obj_loss: -7.065354347229004
train_reg_loss: 2.740723133087158
train_trace_term: -16.71642303466797
train_err_1: 0.0021470217034220695
train_plan_mass: 1.000000238418579
  3%|▎         | 8689/300000 [03:29<1:51:38, 43.49it/s]
train_total_loss: -4.324659824371338
train_obj_loss: -7.065436363220215
train_reg_loss: 2.740776538848877
train_trace_term: -16.722280502319336
train_err_1: 0.0020708555821329355
train_plan_mass: 1.000000238418579
  3%|▎         | 8729/300000 [03:30<1:51:39, 43.48it/s]
train_total_loss: -4.324687957763672
train_obj_loss: -7.065507888793945
train_reg_loss: 2.7408201694488525
train_trace_term: -16.728092193603516
train_err_1: 0.001974385464563966
train_plan_mass: 1.0000001192092896
  3%|▎         | 8769/300000 [03:31<1:51:40, 43.47it/s]
train_total_loss: -4.3247199058532715
train_obj_loss: -7.065642356872559
train_reg_loss: 2.740922451019287
train_trace_term: -16.733491897583008
train_err_1: 0.00235127005726099
train_plan_mass: 1.000000238418579
  3%|▎         | 8809/300000 [03:32<1:51:36, 43.49it/s]
train_total_loss: -4.324747085571289
train_obj_loss: -7.065666198730469
train_reg_loss: 2.740919351577759
train_trace_term: -16.739151000976562
train_err_1: 0.0018935303669422865
train_plan_mass: 1.000000238418579
  3%|▎         | 8849/300000 [03:33<1:51:36, 43.48it/s]
train_total_loss: -4.324771881103516
train_obj_loss: -7.065713882446289
train_reg_loss: 2.7409420013427734
train_trace_term: -16.744436264038086
train_err_1: 0.0020044823177158833
train_plan_mass: 1.0000001192092896
  3%|▎         | 8889/300000 [03:34<1:51:31, 43.51it/s]
train_total_loss: -4.324799060821533
train_obj_loss: -7.065844535827637
train_reg_loss: 2.7410454750061035
train_trace_term: -16.749774932861328
train_err_1: 0.002001760061830282
train_plan_mass: 1.0000001192092896
  3%|▎         | 8929/300000 [03:35<1:51:30, 43.50it/s]
train_total_loss: -4.324821472167969
train_obj_loss: -7.06585693359375
train_reg_loss: 2.7410354614257812
train_trace_term: -16.755348205566406
train_err_1: 0.0019605220295488834
train_plan_mass: 1.000000238418579
  3%|▎         | 8969/300000 [03:36<1:51:30, 43.50it/s]
train_total_loss: -4.324847221374512
train_obj_loss: -7.065942764282227
train_reg_loss: 2.741095542907715
train_trace_term: -16.76028823852539
train_err_1: 0.0017270536627620459
train_plan_mass: 1.0000001192092896
  3%|▎         | 9009/300000 [03:37<1:51:24, 43.53it/s]
train_total_loss: -4.324868679046631
train_obj_loss: -7.066011428833008
train_reg_loss: 2.741142749786377
train_trace_term: -16.76513671875
train_err_1: 0.0017490170430392027
train_plan_mass: 1.000000238418579
  3%|▎         | 9049/300000 [03:38<1:51:26, 43.51it/s]
train_total_loss: -4.324889659881592
train_obj_loss: -7.066054344177246
train_reg_loss: 2.7411646842956543
train_trace_term: -16.770225524902344
train_err_1: 0.0018150557298213243
train_plan_mass: 1.000000238418579
  3%|▎         | 9089/300000 [03:39<1:51:28, 43.49it/s]
train_total_loss: -4.324913501739502
train_obj_loss: -7.066171646118164
train_reg_loss: 2.741258144378662
train_trace_term: -16.774620056152344
train_err_1: 0.001733657787553966
train_plan_mass: 1.000000238418579
  3%|▎         | 9129/300000 [03:39<1:51:26, 43.50it/s]
train_total_loss: -4.324936389923096
train_obj_loss: -7.066231727600098
train_reg_loss: 2.741295337677002
train_trace_term: -16.779170989990234
train_err_1: 0.0017399305943399668
train_plan_mass: 1.0000001192092896
  3%|▎         | 9164/300000 [03:40<1:52:09, 43.22it/s]
train_total_loss: -4.324957847595215
train_obj_loss: -7.066274642944336
train_reg_loss: 2.7413170337677
train_trace_term: -16.783672332763672
train_err_1: 0.0015495780389755964
train_plan_mass: 1.0000001192092896
  3%|▎         | 9209/300000 [03:41<1:51:58, 43.28it/s]
train_total_loss: -4.324978828430176
train_obj_loss: -7.066302299499512
train_reg_loss: 2.741323471069336
train_trace_term: -16.78818130493164
train_err_1: 0.0016214671777561307
train_plan_mass: 1.000000238418579
  3%|▎         | 9249/300000 [03:42<1:51:21, 43.51it/s]
train_total_loss: -4.324997901916504
train_obj_loss: -7.066372871398926
train_reg_loss: 2.7413747310638428
train_trace_term: -16.792335510253906
train_err_1: 0.0016133959870785475
train_plan_mass: 1.0000001192092896
  3%|▎         | 9289/300000 [03:43<1:51:21, 43.51it/s]
train_total_loss: -4.325017929077148
train_obj_loss: -7.066423416137695
train_reg_loss: 2.7414052486419678
train_trace_term: -16.796737670898438
train_err_1: 0.0016729895723983645
train_plan_mass: 1.000000238418579
  3%|▎         | 9329/300000 [03:44<1:51:19, 43.52it/s]
train_total_loss: -4.325038909912109
train_obj_loss: -7.066522598266602
train_reg_loss: 2.7414839267730713
train_trace_term: -16.800338745117188
train_err_1: 0.001481073908507824
train_plan_mass: 1.0000001192092896
  3%|▎         | 9369/300000 [03:45<1:51:21, 43.50it/s]
train_total_loss: -4.325055122375488
train_obj_loss: -7.066555023193359
train_reg_loss: 2.74150013923645
train_trace_term: -16.804317474365234
train_err_1: 0.0014166823821142316
train_plan_mass: 1.0000001192092896
  3%|▎         | 9409/300000 [03:46<1:51:38, 43.38it/s]
train_total_loss: -4.3250732421875
train_obj_loss: -7.066601753234863
train_reg_loss: 2.7415285110473633
train_trace_term: -16.80881118774414
train_err_1: 0.0017240408342331648
train_plan_mass: 1.000000238418579
  3%|▎         | 9449/300000 [03:47<1:51:20, 43.49it/s]
train_total_loss: -4.325095176696777
train_obj_loss: -7.0666656494140625
train_reg_loss: 2.741570472717285
train_trace_term: -16.812110900878906
train_err_1: 0.0012820642441511154
train_plan_mass: 1.000000238418579
  3%|▎         | 9489/300000 [03:48<1:51:17, 43.50it/s]
train_total_loss: -4.325109481811523
train_obj_loss: -7.066751480102539
train_reg_loss: 2.7416419982910156
train_trace_term: -16.815570831298828
train_err_1: 0.0016491971909999847
train_plan_mass: 1.000000238418579
  3%|▎         | 9529/300000 [03:49<1:51:32, 43.40it/s]
train_total_loss: -4.325126647949219
train_obj_loss: -7.066771507263184
train_reg_loss: 2.741645097732544
train_trace_term: -16.819108963012695
train_err_1: 0.0012057197745889425
train_plan_mass: 1.000000238418579
  3%|▎         | 9569/300000 [03:50<1:51:35, 43.38it/s]
train_total_loss: -4.325142860412598
train_obj_loss: -7.06680965423584
train_reg_loss: 2.741666555404663
train_trace_term: -16.822601318359375
train_err_1: 0.0012829561019316316
train_plan_mass: 1.000000238418579
  3%|▎         | 9609/300000 [03:50<1:51:17, 43.49it/s]
train_total_loss: -4.32515811920166
train_obj_loss: -7.066872596740723
train_reg_loss: 2.7417147159576416
train_trace_term: -16.82563018798828
train_err_1: 0.0014641948509961367
train_plan_mass: 1.000000238418579
  3%|▎         | 9649/300000 [03:51<1:51:17, 43.48it/s]
train_total_loss: -4.325170993804932
train_obj_loss: -7.066902160644531
train_reg_loss: 2.7417311668395996
train_trace_term: -16.828948974609375
train_err_1: 0.001414578640833497
train_plan_mass: 1.000000238418579
  3%|▎         | 9689/300000 [03:52<1:51:13, 43.50it/s]
train_total_loss: -4.325189590454102
train_obj_loss: -7.066948890686035
train_reg_loss: 2.7417593002319336
train_trace_term: -16.832250595092773
train_err_1: 0.0010891151614487171
train_plan_mass: 1.0000001192092896
  3%|▎         | 9729/300000 [03:53<1:51:11, 43.51it/s]
train_total_loss: -4.325202941894531
train_obj_loss: -7.066989898681641
train_reg_loss: 2.7417867183685303
train_trace_term: -16.835376739501953
train_err_1: 0.0011326021049171686
train_plan_mass: 1.000000238418579
  3%|▎         | 9769/300000 [03:54<1:51:13, 43.49it/s]
train_total_loss: -4.3252153396606445
train_obj_loss: -7.067033767700195
train_reg_loss: 2.7418181896209717
train_trace_term: -16.837806701660156
train_err_1: 0.0014794000890105963
train_plan_mass: 1.0000001192092896
  3%|▎         | 9809/300000 [03:55<1:51:10, 43.50it/s]
train_total_loss: -4.325231075286865
train_obj_loss: -7.067059516906738
train_reg_loss: 2.741828441619873
train_trace_term: -16.841217041015625
train_err_1: 0.0010220026597380638
train_plan_mass: 1.000000238418579
  3%|▎         | 9849/300000 [03:56<1:51:29, 43.38it/s]
train_total_loss: -4.3252458572387695
train_obj_loss: -7.067102432250977
train_reg_loss: 2.741856575012207
train_trace_term: -16.843894958496094
train_err_1: 0.0010501404758542776
train_plan_mass: 1.0000001192092896
  3%|▎         | 9889/300000 [03:57<1:51:28, 43.37it/s]
train_total_loss: -4.325257301330566
train_obj_loss: -7.067130088806152
train_reg_loss: 2.741872549057007
train_trace_term: -16.846837997436523
train_err_1: 0.001272382098250091
train_plan_mass: 1.0000001192092896
  3%|▎         | 9920/300000 [03:58<1:56:02, 41.67it/s]
train_total_loss: -4.325271129608154
train_obj_loss: -7.067163467407227
train_reg_loss: 2.7418923377990723
train_trace_term: -16.849288940429688
train_err_1: 0.0009368073660880327
train_plan_mass: 1.0000001192092896
Converged after 9921 iterations
train_n_iterations: 9921
train_err_1: 0.0009371092892251909
train_plan_mass: 1.0
train_A_final_norm: 274.18896484375
train_plan_entropy_ratio: tensor([0.7620], device='cuda:0', grad_fn=<DivBackward0>)
A shape: (4031, 134)
A device: cuda:0
[13]:
train_total_loss train_obj_loss train_reg_loss train_trace_term train_err_1 train_plan_mass
245 -4.325231 -7.067060 2.741828 -16.841217 0.001022 1.0
246 -4.325246 -7.067102 2.741857 -16.843895 0.001050 1.0
247 -4.325257 -7.067130 2.741873 -16.846838 0.001272 1.0
248 -4.325271 -7.067163 2.741892 -16.849289 0.000937 1.0
249 NaN NaN NaN NaN 0.000937 1.0

Transport pseudo-unpaired cells and aggregate the plan

After fitting, we use the learned cost to transport pseudo-unpaired RNA cells to pseudo-unpaired ADT cells.

[14]:
transport = model.transport(
    {RNA_MOD: adata_rna, ADT_MOD: adata_adt},
    x_reps={RNA_MOD: "layers/scaled", ADT_MOD: "layers/scaled"},
    store_cost=True,
    store_plan=True,
    max_iter_sink=20 if DEBUG else 1000,
)

plan = transport.materialize_plan().detach().cpu().numpy()
print("plan shape:", plan.shape)
print("plan mass:", plan.sum())
transport.plan_diagnostics
plan shape: (2183, 2183)
plan mass: 1.0000002
[14]:
{'mass': 1.000000238418579,
 'mass_abs_error': 2.384185791015625e-07,
 'x_1_marginal_l1_error': 0.0008025987190194428,
 'x_2_marginal_l1_error': 2.4001928977668285e-07,
 'finite': True,
 'passed': True,
 'tol': 0.001}

Examine the transport plan

For a quick visualization of the validity of the inferred transport plan, we can use provided cell type annotations to aggregate the transported mass across those labels and check mass is transported between coherent cell populations.

[15]:
labels_rna = adata_rna.obs[CELLTYPE_KEY].astype(str).to_numpy()
labels_adt = adata_adt.obs[CELLTYPE_KEY].astype(str).to_numpy()

aggregated_plan, fig, ax = plot_aggregated_transport_plan(
    heat_mtx=plan,
    annotations=labels_rna,
    annotations_ordered=CELLTYPE_ORDER,
    annotations_2=labels_adt,
    annotations_ordered_2=CELLTYPE_ORDER,
    cbar_label="transport mass",
    reduction="sum",
)
../../_images/tutorials_notebooks_cite_interpretability_tutorial_14_0.png

Inspect one marker-specific column of A

A column of A ranks RNA genes by their contribution to the learned cost for one ADT protein. Here we inspect one familiar B-cell marker as an example of how to query and plot the learned interactions.

[16]:
MARKER_PROTEIN = "CD19"
top = model.top_interactions(
    MARKER_PROTEIN,
    modality=ADT_MOD,
    k=10,
    direction="positive",
)
display(
    top[["target_feature", "weight"]].rename(
        columns={"target_feature": "gene", "weight": "A_score"},
    )
)
bar_fig = top_interactions_bar(
    interactions=top,
    title=f"Top positive interactions with {MARKER_PROTEIN}",
    colors="#CA4B55",
)
gene A_score
0 CD24 0.021163
1 WHAMM 0.019635
2 CXXC5 0.019152
3 LINC00926 0.017681
4 MS4A1 0.015109
5 PDLIM1 0.014692
6 BLK 0.014095
7 PAX5 0.013725
8 TRIM59 0.013483
9 ZNF284 0.013372
../../_images/tutorials_notebooks_cite_interpretability_tutorial_16_1.png

Enrichment from one protein-specific A column

The same A column can also be used to define a gene set for enrichment. This is only a small example of the downstream analyses enabled by the learned interaction matrix; it requires internet access for g:Profiler.

[17]:
def genes_for_enrichment(A_df, protein, frac_max=100, max_genes=100):
    if protein not in A_df.columns:
        raise KeyError(f"Protein {protein!r} was not found in A columns.")
    scores = A_df[protein]
    max_abs_score = scores.abs().max()
    if max_abs_score == 0:
        return pd.Series(dtype=float, name=protein)
    selected = scores[scores.abs() > (max_abs_score / frac_max)]
    selected = selected.loc[selected.abs().sort_values(ascending=False).index]
    selected = selected.iloc[:max_genes]
    selected.name = protein
    return selected
[18]:
ENRICHMENT_PROTEIN = "CD18"
gp = GProfiler(return_dataframe=True)
background = list(A_df.index)

enrichment_input = genes_for_enrichment(
    A_df,
    protein=ENRICHMENT_PROTEIN,
    frac_max=100,
    max_genes=100,
)
print(f"{ENRICHMENT_PROTEIN}: top {len(enrichment_input)} genes")
display(enrichment_input.to_frame("A_score").head(20))
CD18: top 100 genes
A_score
CD84 0.029304
AC092279.1 -0.026720
ALOX5AP 0.026240
PCNA 0.023267
KBTBD2 0.022768
NRIP1 -0.021693
AC007262.2 0.021638
CD74 0.021350
ITGB2 0.020158
CA5B 0.019774
ISCA1 0.019579
IRF8 0.018257
PGD -0.017852
SAMD9L 0.017842
ZFAND2B 0.017396
IGSF6 0.017214
GZMH 0.017020
AC010168.2 -0.016554
MTIF2 -0.016174
TANGO2 0.015917
[19]:
enrichment_results = gp.profile(
    organism="hsapiens",
    sources=["KEGG", "REAC", "WP", "GO"],
    query=list(enrichment_input.index),
    background=background,
)
display(enrichment_results)
source native name p_value significant description term_size query_size intersection_size effective_domain_size precision recall query parents
0 WP WP:WP615 Senescence and autophagy in cancer 0.024924 True Senescence and autophagy in cancer 31 87 8 3521 0.091954 0.258065 query_1 [WP:000000]
[ ]: