Canonical Shannon coding implementation

I recently learned about Shannon coding and decided to implement it, and I’d appreciate any feedback. One unusual aspect of this implementation is that the lookup table it eventually returns doubles as the list of symbol frequencies and probabilities during the computation. I do this by initializing each table entry’s Length field with its own offset (the symbol value), which lets me sort the table after counting frequencies without a separate array on the stack or heap. In the final step, I zero out the lengths of non-occurring symbols and swap entries back to their original offsets with the correct length values.
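
For comparison, here is a minimal sketch of the textbook construction that keeps a separate (count, symbol) list, i.e. the extra storage the implementation below avoids. It is not the code under review: `RefCode` and `ShannonReference` are hypothetical names introduced only for illustration, code words are returned as text for readability, and the sketch assumes inputs small enough that `count << len` cannot overflow 64 bits.

```cpp
#include <algorithm>
#include <cstdint>
#include <string>
#include <string_view>
#include <utility>
#include <vector>

// Hypothetical reference types and names, not part of the original code.
struct RefCode { unsigned char Symbol; int Length; std::string Bits; };

std::vector<RefCode> ShannonReference(std::string_view input)
{
  std::uint64_t counts[256] = {};
  for(unsigned char c : input) ++counts[c];

  // separate list of (count, symbol) for the symbols that occur
  std::vector<std::pair<std::uint64_t, unsigned char>> syms;
  for(int s = 0; s < 256; ++s)
    if(counts[s]) syms.emplace_back(counts[s], static_cast<unsigned char>(s));

  // most frequent first; stable, so ties stay in ascending symbol order
  std::stable_sort(syms.begin(), syms.end(),
                   [](const auto& a, const auto& b){ return a.first > b.first; });

  std::vector<RefCode> out;
  std::uint64_t cum = 0; // summed counts of all symbols handled so far
  for(auto [count, symbol] : syms)
  {
    // l = ceil(-log2(count/total)), at least 1:
    // the smallest l >= 1 with count * 2^l >= total
    int len = 1;
    while((count << len) < input.size()) ++len;

    // code word = first l bits of the binary expansion of cum/total
    std::string bits;
    std::uint64_t r = cum;
    for(int i = 0; i < len; ++i)
    {
      r *= 2;
      if(r >= input.size()) { bits += '1'; r -= input.size(); }
      else                  { bits += '0'; }
    }

    out.push_back(RefCode{symbol, len, bits});
    cum += count;
  }
  return out;
}
```

For the input "aabc" this sketch assigns `a` the code `0`, `b` the code `10` and `c` the code `11`, which gives something to cross-check the table-based version against on small inputs.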

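#include <algorithm>
#include <array>
#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <string_view>
#include <utility>

struct Code { uint8_t Length; uint32_t Bits; };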
using CodeTable = std::array<Code, 256>;

//Returns a lookup table for symbols in the input.
CodeTable ComputeTable(std::string_view input)
{
  if(input.length() >= 1ULL<<32 )
    throw std::length_error{"input length exceeds 2^32"};

  if(!input.length()) return {};

  //"CodeTable" is used but fields are temporarily used differently (Length=symbol, Bits=frequency)
  CodeTable syms;
  std::generate(syms.begin(), syms.end(), [c=uint8_t(0)]() mutable { return Code{c++, 0}; });

  //count symbol frequencies
  for(unsigned char c : input) ++syms[c].Bits;

  //sort symbols by descending frequency
  std::sort(syms.begin(), syms.end(), [](auto a, auto b){ return a.Bits > b.Bits; });

  //convert frequencies to probabilities
  for(size_t i = 0; i < syms.size() && syms[i].Bits; ++i)
  {
    syms[i].Bits = (static_cast<uint64_t>(syms[i].Bits) << 32) / input.length();
    //a symbol with 100% probability yields 1<<32, which truncates to 0; clamp it to the largest fraction
    if(!syms[i].Bits) syms[i].Bits = (1ULL<<32)-1;
  }

  //iterate backwards to zero out the length of non-occurring symbols
  for(size_t i = syms.size(); !syms[--i].Bits;) syms[i].Length = 0;

  //compute cumulative probability
  uint32_t cum_p = 0;
  for(size_t i = 0; i < syms.size() && syms[i].Bits; ++i)
  {
    uint32_t sym_p  = syms[i].Bits;
    uint8_t  symbol = syms[i].Length;

    //compute code length as l = ceil(-log_2(p)), i.e. the number of leading zero bits of sym_p plus 1
    uint8_t lz = 0;
    while( !(sym_p & (1u<<(31-lz))) ) if(++lz >= 32) assert(false);
    syms[i].Length = lz+1;

    //set Bits to the current cumulative probability fraction
    syms[i].Bits = cum_p;

    //increment the cumulative probability
    cum_p += sym_p;

    //swap this entry back to its symbol offset in the table
    if(i > symbol || !syms[symbol].Bits) std::swap(syms[i], syms[symbol]);
  }

  return syms;
}
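
To show how an entry of the returned table could be consumed, here is a minimal sketch that renders one entry as a bit string. It assumes the code word is the top `Length` bits of `Bits`, with `Bits` read as the cumulative probability in 0.32 fixed point; that is how Shannon's construction defines the code word, but the code above does not spell it out. `CodeWord` is a hypothetical helper, and `Code` is repeated only so the sketch compiles on its own.

```cpp
#include <cstdint>
#include <string>

// Same layout as the Code struct above, repeated so this sketch is self-contained.
struct Code { std::uint8_t Length; std::uint32_t Bits; };

// Hypothetical helper: returns the code word as text, most significant bit first.
// Assumes the code word is the top `Length` bits of `Bits`.
std::string CodeWord(Code code)
{
  std::string word;
  for(int bit = 0; bit < code.Length && bit < 32; ++bit)
    word += ((code.Bits >> (31 - bit)) & 1u) ? '1' : '0';
  return word;
}
```

With symbol probabilities 1/2, 1/4 and 1/4, the cumulative fractions are 0, 0x80000000 and 0xC0000000 with lengths 1, 2 and 2, so `CodeWord` returns `0`, `10` and `11` respectively.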