Type Safe Linear Algebra in Scala

Thanks to Scala, lately I’ve been appreciating type safety more and more when programming. It was an adjustment coming from Python, R, and C. But the performance benefits, plus the fact that I can be pretty sure my code will run properly once it compiles, mean that I can deploy code with much higher confidence.

However, there’s one area of my development life where type safety hasn’t done much for me: numerical linear algebra and, by extension, machine learning. In this post I’ll explain what the problem is, and propose a non-intrusive way to backport type safety onto linear algebra operations in Scala.

The Problem

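Anyone who has taken a basic linear algebra class or played around with numerical code knows about dimension alignment. In Python with NumPy, a mismatch looks like this (NumPy’s * is elementwise, so this is a broadcasting error; np.dot raises a similar ValueError when the inner dimensions disagree):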

>>> import numpy as np
>>> np.random.rand(2,2) * np.random.rand(3,1)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
ValueError: operands could not be broadcast together with shapes (2,2) (3,1) 

In Scala, using the awesome Breeze library, it looks like this:

scala> import breeze.linalg._
import breeze.linalg._

scala> DenseMatrix.rand(2,2) * DenseMatrix.rand(3,1)
java.lang.IllegalArgumentException: requirement failed: Dimension mismatch!
	at scala.Predef$.require(Predef.scala:233)
	at breeze.linalg.operators.DenseMatrixMultiplyStuff$implOpMulMatrix_DMD_DMD_eq_DMD$.apply(DenseMatrixOps.scala:45)
	at breeze.linalg.operators.DenseMatrixMultiplyStuff$implOpMulMatrix_DMD_DMD_eq_DMD$.apply(DenseMatrixOps.scala:40)
	at breeze.linalg.NumericOps$class.$times(NumericOps.scala:261)
	at breeze.linalg.DenseMatrix.$times(DenseMatrix.scala:54)

That is: if you want to multiply two matrices, their dimensions have to match up in the right way. An (n x d) matrix can only be multiplied on the left by a matrix that’s (something x n), or on the right by a matrix that’s (d x something).

There are two things to notice about the errors above. First, they’re data-dependent: multiplying a (3 x 2) matrix by a (2 x 1) matrix works fine, but change the inner dimension and suddenly you have problems. Second, they’re runtime errors, meaning that, especially in the case of Scala, the code will compile just fine and we will only encounter the error when it runs. Isn’t this what the compiler is supposed to figure out for us?
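To see why such checks can only happen at runtime, here is a minimal, dependency-free Scala sketch in which a matrix is a plain `Array[Array[Double]]` and its shape is just data. `NaiveMatrix` and `multiply` are hypothetical names for illustration, not Breeze’s implementation. The `require` call plays the role of the "requirement failed" check in the stack trace above.

```scala
object NaiveMatrix {
  // A matrix is just rows of doubles; its shape is only known at runtime.
  type M = Array[Array[Double]]

  def rows(m: M): Int = m.length
  def cols(m: M): Int = if (m.isEmpty) 0 else m(0).length

  // (n x d) * (d x k) = (n x k). The inner dimensions can only be compared
  // once the data exists, so a mismatch surfaces as an exception.
  def multiply(a: M, b: M): M = {
    require(cols(a) == rows(b),
      s"Dimension mismatch: (${rows(a)} x ${cols(a)}) * (${rows(b)} x ${cols(b)})")
    Array.tabulate(rows(a), cols(b)) { (i, j) =>
      (0 until cols(a)).map(k => a(i)(k) * b(k)(j)).sum
    }
  }

  def main(args: Array[String]): Unit = {
    val ok = multiply(Array.fill(3, 2)(1.0), Array.fill(2, 1)(1.0)) // (3 x 2) * (2 x 1)
    println(s"${rows(ok)} x ${cols(ok)}") // 3 x 1
    // Same code, different data: this compiles, then fails when it runs.
    val bad = scala.util.Try(multiply(Array.fill(2, 2)(1.0), Array.fill(3, 1)(1.0)))
    println(bad.isFailure) // true
  }
}
```

Nothing in the types distinguishes the two calls; only the data does.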

Matrix-matrix multiplication is at the very heart of advanced analytics, machine learning, and high-performance scientific computing, so it shows up, in non-trivial ways, at the center of some very complicated algorithms. I can’t tell you how many bugs I’ve hit because I forgot to transpose something, or because I assumed the data was coming in one shape when in fact it came in another, and I believe this is a common experience among people who write algorithms like these. Heck, even the theoreticians will tell you that half the work in checking their proofs for correctness is making sure the dimensions line up. (I kid, but only a little.)

A Solution

So how do we avoid this mess and get the type system to check our dimensions for us? If you came to this post hoping to read about Algebraic Data Types and Monads, then I’m sorry, this is not the post you were hoping for. Instead, I’ll propose a simple solution that does the trick without resorting to type system black magic.

The basic observation here is twofold:

  1. Usually people work with a relatively small number of dimensions. That is, I might have “N” data points with “D” features in “K” classes. While each of these numbers might itself be big, there are only 3 of them to keep track of, and I generally know that my data is going to be (N x D) and my model is going to be (D x K), for example.
  2. By forcing the user to provide just a little more information to the type system, we can get type safety for linear algebra in a sensible way.

So, now for the code. First, let’s define a Matrix class with two type parameters, A and B, that label its row and column dimensions, along with some basic operations:

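import breeze.linalg._

// A and B are phantom types: they label the row and column dimensions
// and exist only at compile time.
class Matrix[A,B](val mat: DenseMatrix[Double]) {
  // (A x B) * (B x C) = (A x C): the inner labels must agree.
  def *[C](other: Matrix[B,C]): Matrix[A,C] = new Matrix[A,C](mat * other.mat)
  def t: Matrix[B,A] = new Matrix[B,A](mat.t)
  // Addition and elementwise multiplication need identically labeled shapes.
  def +(other: Matrix[A,B]): Matrix[A,B] = new Matrix[A,B](mat + other.mat)
  def :*(other: Matrix[A,B]): Matrix[A,B] = new Matrix[A,B](mat :* other.mat)
  def *(scalar: Double): Matrix[A,B] = new Matrix[A,B](mat * scalar)
}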

Additionally, I’ll create some helper functions: one to read data in from a file, one to invert a square matrix, and one to build an identity matrix:

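object MatrixUtils {
  // Read a CSV file; the caller supplies the dimension labels A and B.
  def readcsv[A,B](filename: String): Matrix[A,B] = new Matrix[A,B](csvread(new java.io.File(filename)))
  // Only square matrices (same label on both sides) can be inverted.
  def inverse[A](x: Matrix[A,A]): Matrix[A,A] = new Matrix[A,A](inv(x.mat))
  // A d x d identity matrix labeled with dimension D.
  def ident[D](d: Int): Matrix[D,D] = new Matrix[D,D](DenseMatrix.eye[Double](d))
}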

So let’s see it in action:

import breeze.linalg._
import MatrixUtils._

// Dimension labels: never instantiated, used only as type parameters.
class N
class D
class K

val x = new Matrix[N,D](DenseMatrix.rand(100,10))
val y = new Matrix[N,K](DenseMatrix.rand(100,2))

// val z1 = x * x  // Does not compile! x * needs a Matrix[D,_] on the right
val z2 = x.t * y   // Compiles! Returns a Matrix[D,K]
val z3 = x.t * x   // Compiles! Returns a Matrix[D,D]
val z4 = x * x.t   // Compiles! Returns a Matrix[N,N]

What have we done here? We’ve first defined some classes to represent our dimensions (they’re never instantiated; they only serve as labels), then created some matrices and assigned labels to their dimensions. We could just as easily have read x or y from a file with readcsv, provided we knew their intended shapes.
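If you’d like to experiment with the idea without Breeze on the classpath, here is a self-contained sketch of the same phantom-type trick. The row-major `Mat` class, its `Vector[Double]` storage and the sealed-trait labels are assumptions made for this example, not part of the post’s `Matrix` class. As in the post, the labels are trusted, so a `require` in `*` acts as a runtime backstop if a label is attached to data of the wrong shape.

```scala
object PhantomMatrix {
  // Row-major matrix whose row and column dimensions carry phantom type
  // labels R and C. The labels exist only at compile time.
  final class Mat[R, C](val rows: Int, val cols: Int, val data: Vector[Double]) {
    require(data.length == rows * cols, "data length must equal rows * cols")

    def apply(i: Int, j: Int): Double = data(i * cols + j)

    // (R x C) * (C x K) = (R x K): the shared label C is checked by the compiler;
    // the require only fires if a label was attached to data of the wrong shape.
    def *[K](that: Mat[C, K]): Mat[R, K] = {
      require(cols == that.rows, "labels and runtime shapes disagree")
      new Mat[R, K](rows, that.cols, Vector.tabulate(rows * that.cols) { idx =>
        val i = idx / that.cols
        val j = idx % that.cols
        (0 until cols).map(k => this(i, k) * that(k, j)).sum
      })
    }

    // Transposing swaps the labels as well as the data.
    def t: Mat[C, R] =
      new Mat[C, R](cols, rows, Vector.tabulate(rows * cols)(idx => this(idx % rows, idx / rows)))
  }

  // Dimension labels, mirroring the N, D and K classes in the post.
  sealed trait N
  sealed trait D
  sealed trait K

  def main(args: Array[String]): Unit = {
    val x = new Mat[N, D](4, 3, Vector.tabulate(12)(_.toDouble))
    val y = new Mat[N, K](4, 2, Vector.fill(8)(1.0))
    val dk: Mat[D, K] = x.t * y // compiles: (D x N) * (N x K)
    // val bad = x * x          // does not compile: needs a Mat[D, _] on the right
    println(s"${dk.rows} x ${dk.cols}")
  }
}
```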

Finally, we tried some basic linear algebra (matrix multiplication!) on this stuff.

So what?

Well, here’s the punchline: using the classes above, we can now implement something reasonably complicated (say, solving an L2-regularized least-squares problem via the normal equations), be sure that the code is actually going to run if we feed it data of the right shape, and, as a bonus, have the compiler confirm that the method itself gets the dimensions right.

Suppose I want to find the solution to the following problem, where \(\|\cdot\|_F\) is the Frobenius norm (the square root of the sum of squared entries):

\[ \min_{X} \; \|A X - B\|_F^2 + \lambda \|X\|_F^2 \]

A and B are fixed matrices (say, “data” and “labels” in the case of machine learning). One way to solve this is to take the derivative of the above (convex) expression with respect to X and set it to 0. This results in the fairly complicated expression:

\[ X = (A^T A + \lambda I)^{-1} A^T B \]
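The derivative step, written out under the assumption that both squared norms are Frobenius norms (sums of squared entries):

```latex
\[
\nabla_X \left( \|AX - B\|_F^2 + \lambda \|X\|_F^2 \right) = 2A^T (AX - B) + 2\lambda X = 0
\]
\[
\Rightarrow \; (A^T A + \lambda I)\, X = A^T B
\;\Rightarrow\; X = (A^T A + \lambda I)^{-1} A^T B
\]
```

Since A^T A is positive semidefinite, A^T A + λI is invertible for any λ > 0. With A of shape (N x D) and B of shape (N x K), A^T A + λI is (D x D) and A^T B is (D x K), so X is (D x K).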

Or, written with my handy Matrix library:

import MatrixUtils._

// a is (X x Y) and b is (X x Z); the compiler infers a (Y x Z) result.
def solve[X,Y,Z](a: Matrix[X,Y], b: Matrix[X,Z], lambda: Double) = {
  inverse((a.t * a) + ident[Y](a.mat.cols) * lambda) * a.t * b
}

And what does the type signature of solve look like?

solve: [X, Y, Z](a: Matrix[X,Y], b: Matrix[X,Z], lambda: Double)Matrix[Y,Z]

The compiler has figured out that the result of my solve procedure is a (Y x Z) matrix, which in the specific case of my data is (D x K). If you’re familiar with linear regression, this should seem right: one coefficient per feature for each output column.
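Here is a dependency-free sketch of the same `solve`, for readers who want to check the inferred types without Breeze. It assumes a hypothetical minimal `Mat` class and a hypothetical helper `linsolve`, and it swaps the post’s explicit `inverse` for Gauss-Jordan elimination on (A^T A + λI) X = A^T B, which yields the same X without forming the inverse. Treat it as an illustration of the types rather than a numerically robust solver.

```scala
object RidgeSketch {
  // Minimal row-major matrix; R and C are phantom labels for rows and columns.
  final class Mat[R, C](val rows: Int, val cols: Int, val data: Vector[Double]) {
    require(data.length == rows * cols, "data length must equal rows * cols")
    def apply(i: Int, j: Int): Double = data(i * cols + j)

    def *[K](that: Mat[C, K]): Mat[R, K] = {
      require(cols == that.rows, "labels and runtime shapes disagree")
      Mat.tabulate[R, K](rows, that.cols)((i, j) => (0 until cols).map(k => this(i, k) * that(k, j)).sum)
    }
    def t: Mat[C, R] = Mat.tabulate[C, R](cols, rows)((i, j) => this(j, i))
    def +(that: Mat[R, C]): Mat[R, C] = Mat.tabulate[R, C](rows, cols)((i, j) => this(i, j) + that(i, j))
    def *(s: Double): Mat[R, C] = Mat.tabulate[R, C](rows, cols)((i, j) => this(i, j) * s)
  }

  object Mat {
    def tabulate[R, C](rows: Int, cols: Int)(f: (Int, Int) => Double): Mat[R, C] =
      new Mat[R, C](rows, cols, Vector.tabulate(rows * cols)(idx => f(idx / cols, idx % cols)))
    def ident[L](d: Int): Mat[L, L] = tabulate[L, L](d, d)((i, j) => if (i == j) 1.0 else 0.0)
  }

  // Solve m * x = b by Gauss-Jordan elimination with partial pivoting.
  // The signature only accepts a square (Y x Y) left-hand side, like `inverse` in the post.
  def linsolve[Y, Z](m: Mat[Y, Y], b: Mat[Y, Z]): Mat[Y, Z] = {
    val n = m.rows
    require(m.cols == n && b.rows == n, "labels and runtime shapes disagree")
    val aug = Array.tabulate(n, n + b.cols)((i, j) => if (j < n) m(i, j) else b(i, j - n))
    for (col <- 0 until n) {
      val p = (col until n).maxBy(r => math.abs(aug(r)(col))) // pivot row
      require(math.abs(aug(p)(col)) > 1e-12, "matrix is (numerically) singular")
      val tmp = aug(col); aug(col) = aug(p); aug(p) = tmp
      val piv = aug(col)(col)
      for (j <- aug(col).indices) aug(col)(j) /= piv
      for (r <- 0 until n if r != col) {
        val factor = aug(r)(col)
        for (j <- aug(r).indices) aug(r)(j) -= factor * aug(col)(j)
      }
    }
    Mat.tabulate[Y, Z](n, b.cols)((i, j) => aug(i)(n + j))
  }

  // Ridge regression via the normal equations: X = (A^T A + lambda I)^{-1} A^T B.
  def solve[X, Y, Z](a: Mat[X, Y], b: Mat[X, Z], lambda: Double): Mat[Y, Z] =
    linsolve((a.t * a) + Mat.ident[Y](a.cols) * lambda, a.t * b)

  sealed trait N
  sealed trait D
  sealed trait K

  def main(args: Array[String]): Unit = {
    val a = Mat.tabulate[N, D](3, 2)((i, j) => (i + 2 * j + 1).toDouble)
    val b = Mat.tabulate[N, K](3, 1)((i, _) => i.toDouble)
    val z: Mat[D, K] = solve(a, b, 0.1) // the compiler agrees the result is (D x K)
    println(z.data)
    // solve(a.t, b, 0.1)               // does not compile: a.t is (D x N), b is (N x K)
  }
}
```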

And to actually use it:

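val z = solve(x, y, 1e2)   // Matrix[D,K]

val predictions = x * z    // Matrix[N,K]

// Meanwhile, this won't compile: x.t is a Matrix[D,N], so its row label D
// doesn't match y's row label N.
// val z2 = solve(x.t, y, 1e2)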

And that’s it. I can be sure that z has the right shape, because the compiler tells me so, and I can be sure that if I screw up the dimensions somewhere, I’ll be told at compile time rather than 30 minutes into a 2-hour, 100-node job on a cluster.


In this post, I’ve described a problem that, I think, plagues a lot of people who do numerically intensive computing, and proposed a simple solution that relies on the type system to help cope with it. Of course, this isn’t going to solve all problems. For example, if the solution to some problem is square, with the same label on both dimensions, and you forget to transpose it, the compiler can’t catch that for you.
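To make that limitation concrete, here is a small sketch using a stripped-down, hypothetical `Mat` case class (not the post’s `Matrix`). When both dimensions carry the same label D, `w * v` and `w.t * v` both type-check, even though they compute different things.

```scala
object SquareBlindSpot {
  // Same phantom-type idea, stripped down: R and C label rows and columns.
  final case class Mat[R, C](entries: Vector[Vector[Double]]) {
    def t: Mat[C, R] = Mat[C, R](entries.transpose)
    def *[K](that: Mat[C, K]): Mat[R, K] = {
      val thatCols = that.entries.transpose
      Mat[R, K](entries.map(row => thatCols.map(col => row.zip(col).map { case (p, q) => p * q }.sum)))
    }
  }

  sealed trait D

  def main(args: Array[String]): Unit = {
    val w = Mat[D, D](Vector(Vector(1.0, 2.0), Vector(3.0, 4.0)))
    val v = Mat[D, D](Vector(Vector(1.0, 0.0), Vector(1.0, 0.0)))
    // Both lines type-check, because transposing a Mat[D, D] gives a Mat[D, D].
    println((w * v).entries)   // what was intended
    println((w.t * v).entries) // a stray transpose: compiles, different answer
  }
}
```

Giving the two dimensions different labels (say Mat[D, K] even when D and K happen to have the same size) is what lets the compiler tell them apart.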

I haven’t yet built this idea into a real system, but I’d be interested in hearing if this idea has already been implemented in scientific or mathematical computing systems, or if not, why people think this is a bad idea.

Find me on Twitter and make fun of me if you have comments!

posted on 2015-05-28

Currency Arbitrage in Python

Yesterday, there was a post on Hacker News about solving a currency arbitrage problem in Prolog. The problem was originally posted by the folks over at Priceonomics. Spoiler alert: I solve their puzzle in this post.

I’ve actually solved this puzzle before, on a final in my undergraduate algorithms class. I remember being proud of myself for coming up with the solution, and for remembering a trick from high school precalculus that allowed me to get there. (Side note: I now see this trick used basically all the time.) Of course, this was for an algorithms class, so I never actually implemented it.

The solution is this: structure the currency network as a directed graph, with exchange rates on the edges. Then, find a cycle where multiplying the weights together gives you a number > 1. Two things to note. First, this is like solving a longest-path problem. Second, it’s not quite that, because path lengths add weights together, while we’re looking for a product of weights.
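As a tiny illustration of the criterion, here is a Scala sketch (the post’s own solution is in Python) that multiplies the exchange rates around one cycle. The three rates are made-up numbers chosen to produce an arbitrage, not Priceonomics data, and `cycleProduct` is a hypothetical helper.

```scala
object CycleProduct {
  // Hypothetical rates: rates((from, to)) = units of `to` bought with one unit of `from`.
  val rates: Map[(String, String), Double] = Map(
    ("USD", "EUR") -> 0.9,
    ("EUR", "JPY") -> 130.0,
    ("JPY", "USD") -> 0.0087
  )

  // Multiply the rates along a closed cycle such as List("USD", "EUR", "JPY", "USD").
  def cycleProduct(cycle: List[String]): Double =
    cycle.zip(cycle.tail).map(rates).product

  def main(args: Array[String]): Unit = {
    val p = cycleProduct(List("USD", "EUR", "JPY", "USD"))
    println(f"Product around the cycle: $p%.4f (arbitrage if > 1)")
  }
}
```

Checking one given cycle is easy; the hard part, which the rest of the post addresses, is finding such a cycle without enumerating them all.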

The trick I mentioned above is to take the log of each exchange rate. Since the sum of the logs of some terms equals the log of their product, a cycle whose rates multiply to more than 1 is exactly a cycle whose log-rates sum to more than 0. To suit the algorithm we use, we negate the logs and look for a negative-weight cycle instead. It turns out that the Bellman-Ford algorithm detects exactly this.
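Spelled out, with r_1, ..., r_k the exchange rates along a cycle, and using the fact that log is strictly increasing:

```latex
\[
\prod_{i=1}^{k} r_i > 1
\iff \sum_{i=1}^{k} \log r_i > 0
\iff \sum_{i=1}^{k} \left( -\log r_i \right) < 0
\]
```

So an arbitrage opportunity is precisely a negative-weight cycle in the graph whose edge weights are -log r_i.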

Originally, I had plans to redo the author’s Prolog solution with the Bellman-Ford algorithm. While I’m a big fan of declarative programming and first-order logic, I chickened out and decided to redo things in good old Python.

The goal here was to write concise, idiomatic Python with enough structure to be readable, and to use common Python libraries as much as possible. So, I use urllib2 and json to grab and parse data from Priceonomics, and NetworkX to structure the data as a graph and run the Bellman-Ford algorithm.

Of course, it turns out that by default NetworkX raises an exception when a negative cycle is detected, so I had to modify it to return the results of the algorithm when a cycle is detected.

Once I did that, though, I wound up with 50 lines of pretty simple Python. Most of the work is actually in interpreting the results of the Bellman-Ford algorithm.
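For illustration, here is a Scala sketch of the same approach; it is not the author’s Python/NetworkX code. It runs Bellman-Ford on -log(rate) edge weights from a virtual source (so every distance starts at 0) and, like the modified NetworkX call, returns what it found instead of throwing. The "interpreting the results" step is the predecessor walk at the end, which turns the pointers into a cycle. `Edge`, `findArbitrage` and the 1e-12 tolerance (to avoid reporting floating-point noise as arbitrage) are assumptions of this sketch, and the rates in `main` are made up.

```scala
object Arbitrage {
  // Directed edge: one unit of currency `from` buys `rate` units of currency `to`.
  final case class Edge(from: Int, to: Int, rate: Double) {
    val weight: Double = -math.log(rate) // a cycle with rate product > 1 has negative total weight
  }

  // Bellman-Ford from a virtual source joined to every vertex (all distances start at 0).
  // Returns one negative-weight cycle as a closed vertex list, or None if there is none.
  def findArbitrage(n: Int, edges: Seq[Edge]): Option[List[Int]] = {
    val dist = Array.fill(n)(0.0)
    val pred = Array.fill(n)(-1)
    var relaxed = -1
    for (_ <- 0 until n) { // n passes; a relaxation in the last one signals a negative cycle
      relaxed = -1
      for (e <- edges if dist(e.from) + e.weight < dist(e.to) - 1e-12) {
        dist(e.to) = dist(e.from) + e.weight
        pred(e.to) = e.from
        relaxed = e.to
      }
    }
    if (relaxed == -1) None
    else {
      // Step back n times to be sure we are on the cycle itself, then collect it.
      var v = relaxed
      for (_ <- 0 until n) v = pred(v)
      var cycle = List(v)
      var u = pred(v)
      while (u != v) { cycle = u :: cycle; u = pred(u) }
      Some(v :: cycle) // e.g. List(a, b, c, a), in trading order
    }
  }

  def main(args: Array[String]): Unit = {
    // Hypothetical rates, not real market data: 0 = USD, 1 = EUR, 2 = JPY.
    val names = Vector("USD", "EUR", "JPY")
    val edges = Seq(Edge(0, 1, 0.9), Edge(1, 2, 130.0), Edge(2, 0, 0.0087))
    findArbitrage(names.length, edges) match {
      case Some(cycle) => println(cycle.map(names).mkString(" -> "))
      case None        => println("no arbitrage found")
    }
  }
}
```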

posted on 2013-06-08


To keep up appearances, I’ve decided to start blogging. I’m using Jekyll-Bootstrap and Github Pages for now (that’s right, full hipster). All opinions on this blog are my own and do not necessarily reflect those of my employer.

posted on 2013-01-02