Shape-shifting arrays

We continue our plan to implement a J interpreter. Our next obstacle is multidimensional arrays. One challenge is that J arrays are shape-polymorphic: for example, we may at any moment turn our 2x3x4 array into a 3x8 array with the same elements. Furthermore, J arrays must be regular: if an intermediate result is a ragged array, it must be made regular before proceeding.

A second challenge is that functions defined for arrays of low rank must automatically extend to arrays of any rank. For example, addition is defined for two 0-dimensional arrays, that is, single numbers, and we must somehow upgrade it to work on two arrays of any rank, possibly of different ranks.

Shape polymorphism suggests we store the elements in a one-dimensional array and hold the actual dimensions separately. We use Data.Vector to hold the elements, and a plain Haskell list for the dimensions. For example, in a three-dimensional array of shape [3, 4, 2], the element at [i, j, k] lives at index 4*2*i + 2*j + k of the vector.
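The same arithmetic works for any number of dimensions: each index is multiplied by the product of the extents that follow it. A minimal standalone sketch (the names strides and offset are ours, not part of the module below):

```haskell
-- Row-major strides of a shape: the last axis varies fastest, so its
-- stride is 1; each earlier stride is the product of the later extents.
strides :: [Int] -> [Int]
strides shape = tail (scanr (*) 1 shape)

-- Position in the flat element vector of a multi-index into an array
-- of the given shape. No bounds checking is done.
offset :: [Int] -> [Int] -> Int
offset shape idx = sum (zipWith (*) (strides shape) idx)
```

In GHCi, strides [3, 4, 2] gives [8,2,1], and offset [3, 4, 2] [1, 2, 1] gives 13, that is, 4*2*1 + 2*2 + 1.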

Displaying our arrays is a good place to start. For one-, two-, and three-dimensional arrays, J prints the following:

   i.2
0 1
   i.2 3
0 1 2
3 4 5
   i.2 3 4
 0  1  2  3
 4  5  6  7
 8  9 10 11

12 13 14 15
16 17 18 19
20 21 22 23

Experimentation shows that for higher dimensions, J adds one more blank line at each outer level: for instance, the 3-dimensional blocks of a 4-dimensional array are separated by two blank lines.

We can achieve this in a few lines, though we ignore alignment: unlike J, we do not right-align the columns.

module Shaped (Shaped(..), fromList, shapeList, singleton, go1, go2, homogenize) where
import Data.List
import qualified Data.Vector as V
import Data.Vector (Vector, (!))

data Shaped a = Shaped [Int] (Vector a) deriving Eq

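-- Render the subarray of the given shape whose elements start at offset k.
-- Scalars are shown directly, the innermost axis becomes one line, and each
-- outer axis joins its slices with unlines; the nested trailing newlines
-- produce one more blank line per level.
showK :: Show a => Vector a -> [Int] -> Int -> String
showK xs shapeX k =
  case shapeX of
    []     -> show $ xs!k
    [n]    -> unwords $ showK xs [] <$> (k +) <$> [0..n-1]
    (n:ns) -> unlines $ showK xs ns <$> (k +) <$> (product ns *) <$> [0..n-1]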

instance Show a => Show (Shaped a) where
  show (Shaped shapeX xs) = showK xs shapeX 0

-- Construction tools.
singleton :: a -> Shaped a
singleton x = Shaped [] (V.singleton x)

fromList :: [a] -> Shaped a
fromList xs = Shaped [length xs] (V.fromList xs)

shapeList :: [Int] -> [a] -> Shaped a
shapeList shape xs =
  Shaped shape (V.fromList $ take (product shape) (cycle xs))

Behold!

$ ghci shaped.hs
*Shaped> shapeList [2,3,4] [0..]
0 1 2 3
4 5 6 7
8 9 10 11

12 13 14 15
16 17 18 19
20 21 22 23
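Because the shape lives apart from the elements, turning a 2x3x4 array into a 3x8 one only means swapping the shape list. The module does not export such a function; the following hypothetical reshape is a standalone sketch that follows the same cycling convention as shapeList, with a list-backed Shaped standing in for the Vector-backed one so that it needs only base:

```haskell
-- A list-backed stand-in for the article's Shaped type (which uses
-- Data.Vector), so that this sketch depends only on base.
data Shaped a = Shaped [Int] [a] deriving (Eq, Show)

-- Hypothetical reshape: keep the flat elements, swap in a new shape.
-- As with shapeList, elements are cycled if the new shape needs more of
-- them and truncated if it needs fewer. Fails on an empty source array
-- unless the new shape has no elements either.
reshape :: [Int] -> Shaped a -> Shaped a
reshape newShape (Shaped _ xs) =
  Shaped newShape (take (product newShape) (cycle xs))
```

For example, reshape [3, 8] applied to a [2, 3, 4] array of 0 through 23 keeps the elements in the same order and only changes the shape.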

Next, we tackle array fills: we wish to expand a given multidimensional array by lengthening any of its dimensions or by adding more dimensions, filling the new entries with a given element.

The fill function will be used exclusively by homogenize, which we write shortly. We’ll find that homogenize needs only the vector of elements, rather than a Shaped array.

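-- Embed an array in a (possibly larger) shape newRank, filling new slots
-- with z. Missing leading axes count as length 1, so the original array
-- lands in the corner where all the extra indices are 0.
fill :: a -> [Int] -> Shaped a -> Vector a
fill z newRank (Shaped rank vs)
  | newRank == rank = vs
  | otherwise       = V.replicate (product newRank) z V.//
    -- Pair each original element with the flat index, in the new shape,
    -- of its multi-index (enumerated in row-major order).
    zip (sum . zipWith (*) (scanl1 (*) (1:reverse newRank)) . reverse <$>
    sequence (flip take [0..] <$> rank)) (V.toList vs)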

It seems to work:

*Shaped> shapeList [2,3] [0..5]
0 1 2
3 4 5

*Shaped> fill 0 [2,4,3] $ shapeList [2,3] [1..]
[1,2,3,4,5,6,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0]
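fill scatters the old elements into a blank vector. As a cross-check, the same result can be computed by gathering instead: walk every position of the new shape and look up the corresponding old element, if any. This is a hypothetical list-based sketch (indices and fillList are our names), not part of the module:

```haskell
-- Every multi-index of a shape, in row-major order.
indices :: [Int] -> [[Int]]
indices shape = sequence [[0 .. n - 1] | n <- shape]

-- Gather-style fill over plain lists. The old shape is right-aligned
-- against the new one (missing leading axes count as length 1), as fill
-- assumes. Assumes the new shape has at least as many axes as the old.
-- Uses (!!), so it is quadratic: fine for checking small cases only.
fillList :: a -> [Int] -> [Int] -> [a] -> [a]
fillList z newShape oldShape xs = map pick (indices newShape)
  where
    padded  = replicate (length newShape - length oldShape) 1 ++ oldShape
    strides = tail (scanr (*) 1 padded)
    pick idx
      | and (zipWith (<) idx padded) = xs !! sum (zipWith (*) strides idx)
      | otherwise                    = z
```

fillList 0 [2,4,3] [2,3] [1..6] yields the same 24 elements as the fill call above.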

Now for the promised homogenize, which takes a fill value, the dimensions of a surrounding frame (to use J parlance), and a list of Shaped arrays, and produces a regular array just large enough to accommodate all the input arrays within the given frame.

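-- Pad each array (right-aligned) to a common shape, then concatenate the
-- results under the given frame. Assumes xs is non-empty and holds
-- product frame arrays.
homogenize :: a -> [Int] -> [Shaped a] -> Shaped a
homogenize z frame xs = let
  origs = map (\(Shaped rs _) -> rs) xs
  m = maximum $ length <$> origs
  exts = map (\s -> replicate (m - length s) 1 ++ s) origs
  resultRank = foldl1' (zipWith max) exts
  in Shaped (frame ++ resultRank) $ V.concat $ fill z resultRank <$> xs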

This may become clearer once we move on to automatically changing the ranks of functions. The following example homogenizes a 2x2 array, a 3x3x3 array, and a one-dimensional array of size 5, all sitting in a one-dimensional frame (of size 3):

*Shaped> homogenize 0 [3] [shapeList [2,2] [1..], shapeList [3,3,3] [10..], shapeList [5] [20..]]
1 2 0 0 0
3 4 0 0 0
0 0 0 0 0

0 0 0 0 0
0 0 0 0 0
0 0 0 0 0

0 0 0 0 0
0 0 0 0 0
0 0 0 0 0


10 11 12 0 0
13 14 15 0 0
16 17 18 0 0

19 20 21 0 0
22 23 24 0 0
25 26 27 0 0

28 29 30 0 0
31 32 33 0 0
34 35 36 0 0


20 21 22 23 24
0 0 0 0 0
0 0 0 0 0

0 0 0 0 0
0 0 0 0 0
0 0 0 0 0

0 0 0 0 0
0 0 0 0 0
0 0 0 0 0

As expected, we wind up with a 3x3x3x5 array.
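The shape computation at the heart of homogenize can be isolated. This standalone sketch (commonShape is our name) reproduces the inner 3x3x5 of the example; prefixing the frame [3] gives the 3x3x3x5 result:

```haskell
-- Smallest shape that can hold every one of the given shapes: pad each
-- with leading 1s to the same number of axes (right-aligning them, as
-- homogenize does), then take the largest extent along each axis.
-- Unlike homogenize, which calls maximum and foldl1' and so fails on an
-- empty list, this sketch returns [] when given no shapes.
commonShape :: [[Int]] -> [Int]
commonShape [] = []
commonShape shapes = foldr1 (zipWith max) (map pad shapes)
  where
    m     = maximum (map length shapes)
    pad s = replicate (m - length s) 1 ++ s
```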

Rank Polymorphism

In Haskell, we write different code to increment an integer, to increment each integer in a list, and to increment each integer in a list of lists:

> 1 + 2
3
> map (1+) [2..5]
[3,4,5,6]
> map (map (1+)) [[2..5], [4..9]]
[[3,4,5,6],[5,6,7,8,9,10]]

In J, it’s all the same for arrays of all shapes and sizes:

   1 + 2
3
   1 + 2 3 4 5
3 4 5 6
   1 + i. 2 3
1 2 3
4 5 6

Loosely speaking, J takes the increment function, then automatically applies the map function the right number of times so ultimately the integers in the innermost list are incremented.
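To make this loose description concrete in Haskell, here is a sketch on a hypothetical Nested type. It discovers the nesting depth by inspecting values at run time, whereas our Shaped arrays record the number of axes up front in the shape list:

```haskell
-- A possibly ragged nested list of integers (hypothetical type).
data Nested = Atom Int | List [Nested] deriving (Eq, Show)

-- Apply f to every integer, however deeply it is nested: the recursion
-- supplies "the right number of maps" automatically.
deepMap :: (Int -> Int) -> Nested -> Nested
deepMap f (Atom n)  = Atom (f n)
deepMap f (List xs) = List (map (deepMap f) xs)
```

Here deepMap (1+) handles Atom 2 and a list of lists alike, mirroring the three Haskell sessions above with a single function.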

Naturally we might wonder about functions that make sense at multiple levels. For example, const 'x' can be applied to a list, or to elements inside a list:

> const 'x' [[2..5], [4..9]]
'x'
> map (const 'x') [[2..5], [4..9]]
"xx"
> map (map (const 'x')) [[2..5], [4..9]]
["xxxx","xxxxxx"]

J has this covered too. We can specify exactly the level at which a function applies:

   (9:"2) i. 2 3
9
   (9:"1) i. 2 3
9 9
   (9:"0) i. 2 3
9 9 9
9 9 9

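With a little thought, it becomes apparent how J works: a verb of rank n operates on the subarrays formed by the last n axes of its argument, while the remaining leading axes form the frame. The verb is applied to each such cell, and the results are assembled into an array whose shape is the frame followed by the common shape of the results, padded with a fill value where they differ. With a little more thought, we can apply the same trick to binary operators (okay, dyads). Or we can just look up verb ranks and frames.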
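The splitting step can be sketched apart from any particular array representation. In the helpers below, splitFrame and cellsOf are our names and a flat list stands in for the Vector; they divide an argument into a frame and cells the way go1 does:

```haskell
-- Split a shape into (frame, cell shape) for a verb of rank r. The cell
-- is made of the trailing axes; a rank larger than the array's own rank
-- is capped, so the verb then sees the whole array (as in go1).
splitFrame :: Int -> [Int] -> ([Int], [Int])
splitFrame r shape = splitAt (length shape - min r (length shape)) shape

-- Cut a flat row-major element list into its cells for a rank-r verb,
-- returning the frame along with one list of elements per cell.
cellsOf :: Int -> [Int] -> [a] -> ([Int], [[a]])
cellsOf r shape xs =
  (frame, [take size (drop (i * size) xs) | i <- [0 .. product frame - 1]])
  where
    (frame, cellShape) = splitFrame r shape
    size = product cellShape
```

For i. 2 3 (shape [2,3], elements 0 through 5), cellsOf 2 gives frame [] and one cell, cellsOf 1 gives frame [2] and the two rows, and cellsOf 0 gives frame [2,3] and six one-element cells. Mapping const 9 over the cells and taking the frame as the result shape accounts for the three J outputs above; mapping sum over the rank-1 cells instead gives the row sums [3,12].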

-- Takes a fill value, rank, function, and input Shaped array.
-- Runs the function using the given rank (or the input array rank, whichever
-- is lower), using the given fill value if needed.
go1 :: a -> Int -> (Shaped a -> Shaped a) -> Shaped a -> Shaped a
go1 z mv v (Shaped shape xs) = homogenize z frame
  [v (Shaped rank $ V.slice (i*sz) sz xs) | i <- [0..product frame-1]]
  where
    (frame, rank) = splitAt (length shape - min mv (length shape)) shape
    sz = product rank

-- Two-argument variant of the above.
-- Takes a fill value, left and right rank, function, and two input Shaped
-- arrays.
go2 :: a -> Int -> Int -> (Shaped a -> Shaped a -> Shaped a) ->
  Shaped a -> Shaped a -> Shaped a
go2 z lv rv v (Shaped shapeX xs) (Shaped shapeY ys)
  | or $ zipWith (/=) frameX frameY = error "frame mismatch"
  | length frameX > length frameY =
    f (flip v) (frameY, rankY) ys (frameX, rankX) xs
  | otherwise =
    f        v (frameX, rankX) xs (frameY, rankY) ys
  where
    dimL = length shapeX - min lv (length shapeX)
    dimR = length shapeY - min rv (length shapeY)
    (frameX, rankX) = splitAt dimL shapeX
    (frameY, rankY) = splitAt dimR shapeY

    f v (frameX, rankX) xs (frameY, rankY) ys = homogenize z frameY $
       concat [[v (Shaped rankX $ V.slice (i*xsize) xsize xs)
                  (Shaped rankY $ V.slice ((i*m + j)*ysize) ysize ys)
                | j <- [0..m-1] ] | i <- [0..product frameX-1]]
      where
        xsize = product rankX
        ysize = product rankY
        -- Cells of the longer-framed argument per cell of the shorter one.
        m = div (V.length ys * xsize) (V.length xs * ysize)
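The agreement rule and the count m in go2 can be isolated. This standalone sketch (agree is our name) returns Nothing where go2 raises "frame mismatch". When neither argument is empty, its count equals go2's m, since m reduces to the product of the longer frame divided by the product of the shorter:

```haskell
-- Frame agreement for a dyad, in the style of go2: the shorter frame must
-- be a prefix of the longer one. On success, return the longer frame
-- (the frame of the result) and the number of cells of the longer-framed
-- argument paired with each cell of the shorter-framed one.
agree :: [Int] -> [Int] -> Maybe ([Int], Int)
agree fx fy
  | short == take k long = Just (long, product (drop k long))
  | otherwise            = Nothing
  where
    (short, long) = if length fx <= length fy then (fx, fy) else (fy, fx)
    k = length short
```

With the frames [4,2] and [4,2,5] from the go2 example below, each left cell pairs with 5 right cells, which is why each left element is added to a row of five.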

By design, if a J verb has rank n, then it is defined for every rank up to and including n: an argument of lower rank is treated as a single cell, which is why go1 and go2 cap the verb rank at the rank of the argument. This fits in with J’s automatic extension of verbs to any rank: not only can we omit the equivalent of Haskell’s map, but we can also omit the calls that wrap values in singleton lists and the like.

To test special cases of the above, we add a couple of helpers:

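-- Lift scalar functions to Shaped arrays. These only match 0-dimensional
-- arguments, so they suit verbs of rank 0.
test1 :: (a -> a) -> Shaped a -> Shaped a
test1 f (Shaped [] xs) = singleton (f $ xs!0)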

test2 :: (a -> a -> a) -> Shaped a -> Shaped a -> Shaped a
test2 f (Shaped [] xs) (Shaped [] ys) = singleton (f (xs!0) (ys!0))

Then:

*Shaped> go1 0 0 (test1 (^2)) $ shapeList [2, 3, 4] [1..]
1 4 9 16
25 36 49 64
81 100 121 144

169 196 225 256
289 324 361 400
441 484 529 576


*Shaped> go2 0 0 0 (test2 (+)) (shapeList [4,2] [1..]) (shapeList [4,2,5] [10..])
11 12 13 14 15
17 18 19 20 21

23 24 25 26 27
29 30 31 32 33

35 36 37 38 39
41 42 43 44 45

47 48 49 50 51
53 54 55 56 57

Ideally we should test ranks higher than 0 as well, but we’ll make do with indirect tests when we write our J interpreter. The next hurdle is understanding J’s numeric types.


Ben Lynn blynn@cs.stanford.edu 💡