Jaccardovo tajemství – jak počítat podobnost množin pomalu, jak ji počítat rychle a jak při výpočtu podvádět

Jaccardův index podobnosti je jednoduchá funkce, která udává míru podobnosti mezi dvěma množinami. Je definována jako velikost průniku vydělená velikostí sjednocení dvou množin.

J(A, B) = |A ∩ B| / |A ∪ B|

Funkce je to jednoduchá. Otázka je, jak ji implementovat, aby běžela rychle. V následujících odstavcích se vydám na cestu za největší efektivitou za každou cenu. A když říkám za každou cenu, myslím tím, že se skutečně před ničím nezastavím.

Když půjde všechno podle plánu, cestou se možná dostaví jeden nebo dva momenty osvícení.


Naivní implementace ve Scale by mohla vypadat nějak takhle:

def jacc(a: Set[Int], b: Set[Int]): Double = {
  val is = a.intersection(b).size
  val us = a.union(b).size
  if (us == 0) 0 else is.toDouble / us
}

Jednoduchý kód, strašlivý výkon. Problém spočívá v tom, že je třeba vytvořit dvě množiny jen proto, abych zjistit jejich velikost. Velikost sjednocení není třeba vůbec počítat, protože se dá jednoduše odvodit z principu inkluze a exkluze.

def jacc(a: Set[Int], b: Set[Int]): Double = {
  val is = a.intersection(b).size
  val us = a.size + b.size - is
  if (us == 0) 0 else is.toDouble / us
}

To je lepší, ale pořád je třeba vytvořit jednu množinu se všemi alokacemi a interními režiemi, které to obnáší. Logickým krokem je nic nealokovat a v jedné iteraci spočítat velikost průniku.

def jacc(a: Set[Int], b: Set[Int]): Double = {
  val small = (if (a.size < b.size) a else b
  val big   = (if (b.size < a.size) a else b
  val is = small.count { el => big.contains(el) }
  val us = a.size + b.size - is
  if (us == 0) 0 else is.toDouble / us
}

To je lepší, ale zdaleka ne ideální. Problém může představovat uspořádání dat a layout paměti. V případě JVM má generický HashSet celkem velkou režii a mizernou lokalitu3. Set[Int] neuchovává primitivní čtyřbajtové inty, ale reference na boxované Integer objekty. Kombinace ref+box může zabírat klidně 32 bajtů na 64-bitovém systému a musí udělat jednu dereferenci pointeru.

Tomu se dá vyhnout používáním jazyka/runtime, který dělá specializaci typů (reifikovaná generikav C# nebo C++ šablony) nebo kolekcemi specializovanými pro primitivní typy. Na JVM je k dispozici několik takových knihoven a jedna z nejlepších je Koloboke.

import net.openhft.koloboke.collect.set.hash.HashIntSet

def jacc(a: HashIntSet, b: HashIntSet): Double = {
  val small = (if (a.size < b.size) a else b
  val big   = (if (b.size < a.size) a else b

  var is = 0
  val cur = small.cursor
  while (cur.moveNext()) {
    if (big.contains(cur.elem)) {
      is += 1
    }
  }

  val us = a.size + b.size - is
  if (us == 0) 0 else is.toDouble / us
}

Kód je o něco delší, ale na druhou stranu může být mnohem rychlejší. Data jsou uložena v plochých polích a nepotřebují nahánět pointery.

Všechny předchozí změny představovaly pokrok v mezích zákona, pozvolné zlepšování jednoho řešení. Nešlo však o žádné radikální skoky vpřed. Těch můžu dosáhnout jedině, když to vezmu z druhého konce a začnu přemýšlet o tom, co je skutečně potřeba. V tomto případě mě zajímá jen Jaccardova podobnost, nic jiného. Všechny reprezentace množin, které jsem doteď používal byly založeny na hash tabulkách a nabízely tedy mnoho jiné funkcionality. Uměly například v konstantním čase zjistit, zdali se daný element nachází v množině. Já však potřebuji jen rychlý výpočet velikosti průniku. Když budu reprezentovat množinu seřazeným polem, dá se právě tahle veličina spočítat velice efektivně.

def jacc(a: Array[Int], b: Array[Int]): Double = {
  def intersectionSize(a, b) = {
    var ai, bi, size = 0
    while (ai < a.length && bi < b.length) {
      val av = a(ai)
      val bv = b(bi)
      size += (if (av == bv) 1 else 0)
      ai   += (if (av <= bv) 1 else 0)
      bi   += (if (av >= bv) 1 else 0)
    }
    size
  }

  val is = intersectionSize(a, b)
  val us = a.length + b.length - is
  if (us == 0) 0 else is.toDouble / us
}

Stačí velice jednoduchá smyčka, která udělá lineární průchod oběma poli a nepotřebuje dělat žádné hledání v interních hash tabulkách. Tři řádky ve tvaru x += (if (cond) 1 else 0) kompilátoru/JITu silně naznačují, aby místo podmíněných skoků použil cmov instrukce1. To odstraní potenciální nepředvídatelný skok v těle smyčky, který by všechno mohl výrazně zpomalit.

Tento kód je nejen velice rychlý, ale také překvapivě jednoduchý. Funguje tak, že hledá shodné prvky v poli. Když je najde, inkrementuje proměnnou size a i oba indexy. V ostatních případech inkrementuje index, který ukazuje na menší prvek. Jde o algoritmus velice podobný merge sortu.

Tělo smyčky obsahuje pouhých ±20 instrukcí a nepůjde zrychlit redukcí počtu operací v jedné iteraci, ale zmenšením počtu iterací.

Jedním způsobem jak toho dosáhnout, je dívat se n míst dopředu4 s tím, že když najdu element menší než ten hledaný, můžu přeskočit n míst a tím pádem i n iterací. Tělo smyčky se trochu zkomplikuje, výsledek však často je o pár desítek procent rychlejší a jen málokdy dojde ke zpomalení. Pokud je skok malý, kód přeskakuje jen malé úseky a neušetří příliš iterací. Na druhou stranu, když je skok velký, kód nemůže nic přeskočit a iteruje jako normální verze. Je potřeba najít nějaké přijatelné optimum.

while (ai < alen && bi < blen) {
  val av = a(ai)
  val bv = b(bi)
  val _ai = ai
  val _bi = bi
  size += (if (av == bv) 1 else 0)
  ai   += (if (av <= bv) (if (a(_ai+skip) < bv) skip else 1) else 0)
  bi   += (if (av >= bv) (if (b(_bi+skip) < av) skip else 1) else 0)
}

Pokud bych chtěl jít ještě dál, mohl bych pole předzpracovat a za každý element vložit předpočítanou bitmapu obsahující osm čtyřbitových čísel, které udávají, jak daleko může jeden index poskočit v závislosti na rozdílu porovnávaných hodnot. Za zmínku stojí, že kód obsahuje jen rychlé bitové operace a nepotřebuje žádný podmíněný extra skok. Je třeba jen něco přes 10 instrukcí navíc.

def intersectionSizeWithEmbeddedSkiplists(a: Array[Int], b: Array[Int]): Int = {
  var size, ai, bi = 0
  while (ai != a.length && bi != b.length) {
    val av = a(ai)
    val bv = b(bi)

    val s  = (if (av < bv) av else bv)
    val si = (if (av < bv) ai else bi)

    val d = java.lang.Math.abs(av - bv)
    val bits = 32 - Long.numberOfLeadingZeros(d)
    val slot = bits / 4

    val slotval = (s(si+1) >>> (slot * 4)) & ((1 << 4) - 1)
    val skip = slotval << (slot - 1)

    size += (if (av == bv) 1 else 0)
    ai   += (if (av <= bv) skip else 0)
    bi   += (if (av >= bv) skip else 0)
  }
  size
}

O kolik nebo jestli vůbec to zrychlí výsledek jsem ale netestoval, protože mi došla kuráž, když jsem začal přemýšlet jak napsat funkci, která vypočítá možné skoky.

Ale to stále není všechno. Na začátku jsem psal za každou cenu a pořád to myslím vážně.

Když zajdu do extrému, je možné implementovat Jaccarda pomocí AVX2 SIMD instrukcí2, které prohledávají vektor osmi hodnot paralelně. Jak takováto hrůza vypadá se můžete přesvědčit na vlastní oči v tomto gistu. Vektorizované řešení je v nejhorším případě, kdy nikdy není možné přeskočit několik iterací, 2× pomalejší než přímočará implementace (protože potřebuje vykonat víc instrukcí a každá iterace dělá víc práce), ale v nejlepším případě, kdy může často přeskakovat velký kus vstupního pole, až 4× rychlejší.

Pro další zrychlení je možné na začátku kontrolovat jestli je začátek jednoho pole větší než konec toho druhého. V takovém případě je jasné, že množiny nemají žádný společný prvek a Jaccardova podobnost bude vždycky 0. Ke stejnému účelu se dá použít bitmapa (např. 64 bitů) fungující jako maličký Bloom filtr. Když udělám logický and dvou bitmap a dostanu 0 (tj. dvě mapy nemají žádné společné bity), je jasné, že dvě množiny, ze kterých byly tyto bitmapy odvozeny nemají žádné společné prvky a Jaccardova podobnost bude opět nulová.

if (b(0) < a(a.length-1) ||
    a(0) < b(b.length-1) ||
    (aBitmap & bBitmap) == 0) {
  return 0.0
}

To ale stále není všechno. Když mě nezajímá přesná Jaccardova podobnost, ale vystačím si s odhadem, můžu použít MinHash. Ten produkuje jen přibližné výsledky5, ale může být výrazně rychlejší, protože nepočítá podobnost mezi celými množinami, ale jen jejich otisky, které mají fixní velikost.

MinHash a mnoho dalších skečů jsem implementoval v knihovně sketches. S ní se dá spočítat odhad podobnosti velice jednoduše:

val mh = atrox.sketches.MinHash(sets, 128)
mh.estimateSimilarity(i, j)

Když ani tohle nestačí, pomůže už jen locality sensitive hashing (LSH) (viz Mining of Massive Datasets, kapitola 3.3 a 3.4)6. LSH může výpočet podobnosti výrazně zrychlit, protože omezuje hledání jen na kandidáty, které jsou s velkou pravděpodobností podobné a zcela přeskočí ty, které jsou (opět s velkou pravděpodobností) nepodobné.

A knihovna sketches také obsahuje implementaci LSH.

val mh = atrox.sketches.MinHash(sets, 128)
val lsh = atrox.sketches.LSH(mh, bands = 32)
lsh.similarItems(i, similarityThreshold)

S vhodně nastavenou LSH je možné úlohu, která by trvala 24 hodin hrubou silou, spočítat za 24 vteřin s celkem rozumnou ztrátou přesnosti.

Myslím, že rychleji než tohle už to není možné.


Dále k tématu:


  1. Technicky vzato kompilátor může tento řádek přeložit na trojici instrukcí cmp, setX a add a ani nepotřebuje cmov.
  2. Síla SIMD operací je vidět v Někdy je nejchytřejší nedělat nic chytrého
  3. V případě neměnných kolekcí je to ještě horší, protože jsou interně implementované jako hash array mapped trie a to s sebou přináší další úrovně pointerů a referencí.
  4. viz Introduction to information retrieval, kapitola 2.3: Faster posting list intersection via skip pointers
  5. Někoho tady může napadnout, že když mi stačí odhady, můžu použít HyperLogLog k odhadnutí velikosti sjednocení a pak dopočítat průnik, před princip inkluze a exkluze. To funguje, ale není to příliš přesné, protože chyba je relativní vzhledem k velikosti sjednocení a nikoli k průniku, který může být mnohem menší.
  6. Když mluvím o LSH, měl bych se také zmínit, že existuje alternativní přístup hledání nejbližších sousedů, který není postavený na hashování, ale na binárních stromech. Přibližně to odpovídá rozdílu mezi hash tabulkami a binárními vyhledávacími stromy – hashování nabízí O(1) hledání, stromy hledají v čase O(log n), ale jejich obsah je seřazený. To v případě hledání nejbližších sousedů znamená, že si můžu říct o body, které jsou trochu dál, což LSH nedokáže.

Flattr this!

This entry was posted in Algo, DS, Scala. Bookmark the permalink.

3 Responses to Jaccardovo tajemství – jak počítat podobnost množin pomalu, jak ji počítat rychle a jak při výpočtu podvádět

  1. Aleš Hájek says:

    A nebylo by jednodušší vytvořit dvojice (množina, hodnota) ty setřídit podle hodnoty a pak jedním průchodem vypsat výsledek, než tato hrůza? Navíc tuto jednoduchou databázovou operaci zvládne každá databáze pro miliony prvků do jedné sekundy.

    • Nevím jak přesně se tohle týká Jaccardovy podobnosti, tak se přikloním k tomu, že by to nebylo jednodušší. Co přesně má být ta hodnota? Jestli jde o prvky obou množin, které jsou označkované do jaké množiny patří, pak by to fungovalo, ale je to zbytečná práce a alokace navíc. Ta hrůza ve své podstatě dělá jednu iteraci merge sortu, která ze dvou seřazených polí vyrobí další pole až na to, že výsledek není nikdy materializovaný, ale okamžitě je zredukován na velikost průniku.

      Ta hrůza umí spočítat Jaccarda/průnik dvou množin, z nichž každá má milion prvků, za 8.4 milisekundy.

Leave a Reply

Your email address will not be published. Required fields are marked *