Last active
September 26, 2025 17:41
-
-
Save axionbuster/e94a7ba66ed3b0e23d9b696a48ffeba5 to your computer and use it in GitHub Desktop.
The SMAWK algorithm
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| // SMAWK algorithm; similar to inter.zig, which only implements | |
| // INTERPOLATE. due to my poor Zig skills, I had to duplicate some code. | |
| const std = @import("std"); | |
// State shared by the two halves of SMAWK, REDUCE and INTERPOLATE.
//
// A `Go` describes a virtual submatrix of the caller's matrix:
//   - virtual row i is real row `offset + i * scale` (only real rows below
//     `height` exist);
//   - virtual column j is real column `columns[j]`, or simply j when
//     `columns` is null (all `width` real columns are in play).
// Row minima are stored as real column indices in `minima`, which is indexed
// by real row and shared by every level of the recursion.
const Go = struct {
    // actual row index begins at offset, and goes up by scale
    offset: usize,
    scale: usize,
    // width and height are measured in actual (real) indices
    width: usize,
    height: usize,
    // locations of the minima of rows: real column indices, indexed by real row
    minima: []usize,
    // SMAWK columns: maps virtual columns to real column indices (ascending,
    // as produced by `reduce`); null means every real column 0..width-1
    columns: ?[]usize,
    // entry lookup, value_(real_row, real_column), 0-based indexing
    value_: *const fn (usize, usize) i32,
    // allocator for the temporary column stacks built by `reduce`
    alloc: std.mem.Allocator,

    // Convention: parameters named i and j are *virtual* row and column
    // indices of this submatrix; names containing "Real" refer to indices
    // into the caller's matrix.

    // Real row index of virtual row `i`.
    fn rowOf(this: *const Go, i: usize) usize {
        return i * this.scale + this.offset;
    }

    // Whether virtual row `i` corresponds to an existing real row.
    fn rowInBounds(this: *const Go, i: usize) bool {
        return this.rowOf(i) < this.height;
    }

    // Real column index of virtual column `j`.
    fn virtualToRealColumn(this: *const Go, j: usize) usize {
        return if (this.columns) |cols| cols[j] else j;
    }

    // Number of virtual columns.
    fn virtualWidth(this: *const Go) usize {
        return if (this.columns) |cs| cs.len else this.width;
    }

    // Entry at virtual row `i`, virtual column `j`.
    fn value(this: *const Go, i: usize, j: usize) i32 {
        return this.value_(this.rowOf(i), this.virtualToRealColumn(j));
    }

    // Real column at which a scan bounded by virtual row `i` ends (inclusive):
    // the recorded minimum of row `i` when that row exists, otherwise the last
    // column of this submatrix.
    fn minimumIxOfReal(this: *const Go, i: usize) usize {
        if (this.rowInBounds(i)) {
            return this.minima[this.rowOf(i)];
        } else if (this.columns) |cs| {
            if (cs.len == 0) {
                std.debug.panic("minimumIxOfReal: no columns", .{});
            } else {
                return cs[cs.len - 1]; // out of bounds
            }
        } else {
            return this.width - 1; // out of bounds
        }
    }

    // Record virtual column `j` as the minimum of virtual row `i` (stored as a
    // real column index). Panics if row `i` does not exist.
    fn setMinimumIxOf(this: *Go, i: usize, j: usize) void {
        if (this.rowInBounds(i)) {
            this.minima[this.rowOf(i)] = this.virtualToRealColumn(j);
        } else {
            std.debug.panic(
                "setMinimumIxOf: row ({d} -> {d}) is outside of the bounds",
                .{ i, this.rowOf(i) },
            );
        }
    }

    // Number of virtual rows: ceiling division of (height - offset) by scale.
    // Requires offset <= height, which every caller guarantees (the root has
    // offset 0 and recursion only descends when the new offset row exists).
    fn virtualHeight(this: *const Go) usize {
        return (this.height - this.offset + this.scale - 1) / this.scale;
    }

    // INTERPOLATE: solve the odd virtual rows recursively, then fill in each
    // even row by scanning only between its neighbours' minima.
    fn interpolate(this: *Go) std.mem.Allocator.Error!void {
        if (this.rowInBounds(1)) {
            // The odd rows form a submatrix with twice the stride over the same
            // columns; solving it records their minima in `minima`.
            var go2 = this.*;
            go2.offset = this.offset + this.scale;
            go2.scale = this.scale * 2;
            try go2.reduce();
        }
        // note: width >= 1 here (the root has width >= 1 and reduce keeps at
        // least one column).
        const width = this.virtualWidth();
        var i: usize = 0;
        var j: usize = 0;
        while (this.rowInBounds(i)) : (i += 2) {
            // Row i's leftmost minimum lies between the column where row i-2's
            // was found (j) and row i+1's minimum (inclusive).
            const limitReal = this.minimumIxOfReal(i + 1);
            var minv = this.value(i, j);
            var minj = j;
            while (j < width and this.virtualToRealColumn(j) < limitReal) {
                j += 1;
                const v = this.value(i, j);
                // since we're finding the leftmost minimum, use '<', not '<='.
                if (v < minv) {
                    minv = v;
                    minj = j;
                }
            }
            this.setMinimumIxOf(i, minj);
            j = minj; // start from the last minimum
        }
    }

    // REDUCE: when there are more columns than rows, discard columns that
    // cannot hold any row's leftmost minimum (keeping at most one per row),
    // then INTERPOLATE over the survivors.
    fn reduce(this: *Go) std.mem.Allocator.Error!void {
        const width = this.virtualWidth();
        const height = this.virtualHeight();
        if (width <= height or height <= 1) {
            return this.interpolate();
        }
        // Stack of surviving virtual columns; its size `t` never exceeds the
        // number of virtual rows.
        const cols = try this.alloc.alloc(usize, height);
        // One unconditional `defer` owns `cols` for the whole call. Pairing an
        // `errdefer` with an explicit free before the recursive call would free
        // it twice if that call failed.
        defer this.alloc.free(cols);
        var t: usize = 0;
        var k: usize = 0;
        while (k < width) : (k += 1) {
            // Pop columns beaten by column k on the row matching the stack top;
            // strict '>' keeps the earlier column on ties (leftmost minima).
            while (t > 0 and this.value(t - 1, cols[t - 1]) > this.value(t - 1, k)) {
                t -= 1;
            }
            if (t < height) {
                cols[t] = k;
                t += 1;
            }
        }
        // Translate survivors to real column indices in place. This reads only
        // `this.columns`, never `cols`, so overwriting entries is safe, and the
        // ascending order of the stack is preserved.
        const survivors = cols[0..t];
        for (survivors) |*c| c.* = this.virtualToRealColumn(c.*);
        var go2 = this.*;
        go2.columns = survivors;
        try go2.interpolate();
    }
};
/// Compute the leftmost minimum of every row of a `height` x `width` matrix
/// whose entry at (row, column) is `value(row, column)`, using SMAWK.
///
/// SMAWK relies on the matrix being totally monotone (every Monge matrix is);
/// for other matrices the result is not meaningful.
///
/// Returns a slice of length `height` holding a column index per row, or an
/// empty slice when either dimension is zero. The caller owns the returned
/// slice and must free it with `alloc`.
pub fn interpolate(
    alloc: std.mem.Allocator,
    width: usize,
    height: usize,
    value: fn (usize, usize) i32,
) ![]usize {
    // Degenerate matrix: nothing to search.
    if (height == 0 or width == 0) return try alloc.alloc(usize, 0);

    const minima = try alloc.alloc(usize, height);
    errdefer alloc.free(minima);

    // Root of the recursion: every real row (stride 1) and every real column.
    var root: Go = .{
        .alloc = alloc,
        .value_ = value,
        .minima = minima,
        .columns = null,
        .width = width,
        .height = height,
        .offset = 0,
        .scale = 1,
    };
    try root.reduce();
    return minima;
}
test "interpolate 1 (smawk)" {
    const gpa = std.testing.allocator;
    // 5x5 Monge matrix exercising ties (row 1) and repeated minima columns.
    const Matrix = struct {
        const rows = [5][5]i32{
            .{ 12, 21, 38, 76, 27 },
            .{ 74, 14, 14, 29, 60 },
            .{ 21, 8, 25, 10, 71 },
            .{ 68, 45, 29, 15, 76 },
            .{ 97, 8, 12, 2, 6 },
        };
        fn at(i: usize, j: usize) i32 {
            return rows[i][j];
        }
    };
    const minima = try interpolate(gpa, 5, 5, Matrix.at);
    defer gpa.free(minima);
    try std.testing.expectEqualSlices(usize, &[_]usize{ 0, 1, 1, 3, 3 }, minima);
}
// A wide (18 x 9) Monge matrix, so REDUCE actually has to discard columns.
test "interpolate 2 (smawk)" {
    const allocator = std.testing.allocator;
    const M = [9][18]i32{
        .{ 25, 21, 13, 10, 20, 13, 19, 35, 37, 41, 58, 66, 82, 99, 124, 133, 156, 178 },
        .{ 42, 35, 26, 20, 29, 21, 25, 37, 36, 39, 56, 64, 76, 91, 116, 125, 146, 164 },
        .{ 57, 48, 35, 28, 33, 24, 28, 40, 37, 37, 54, 61, 72, 83, 107, 113, 131, 146 },
        .{ 78, 65, 51, 42, 44, 35, 38, 48, 42, 42, 55, 61, 70, 80, 100, 106, 120, 135 },
        .{ 90, 76, 58, 48, 49, 39, 42, 48, 39, 35, 47, 51, 56, 63, 80, 86, 97, 110 },
        .{ 103, 85, 67, 56, 55, 44, 44, 49, 39, 33, 41, 44, 49, 56, 71, 75, 84, 96 },
        .{ 123, 105, 86, 75, 73, 59, 57, 62, 51, 44, 50, 52, 55, 59, 72, 74, 80, 92 },
        .{ 142, 123, 100, 86, 82, 65, 61, 62, 50, 43, 47, 45, 46, 46, 58, 59, 65, 73 },
        .{ 151, 130, 104, 88, 80, 59, 52, 49, 37, 29, 29, 24, 23, 20, 28, 25, 31, 39 },
    };
    const Go2 = struct {
        fn value(i: usize, j: usize) i32 {
            return M[i][j];
        }
    };
    const minima = try interpolate(allocator, 18, 9, Go2.value);
    // Register the free before any expectation can fail, so a failed check is
    // not also reported as a leak by the testing allocator.
    defer allocator.free(minima);
    const expect = [_]usize{ 3, 3, 5, 5, 9, 9, 9, 9, 13 };
    // expectEqualSlices checks the length (9) as well as every element.
    try std.testing.expectEqualSlices(usize, &expect, minima);
}
// Brute-force reference for row minima of an arbitrary matrix (test-only).
// For each row, returns the leftmost column holding that row's minimum.
// The caller owns the returned slice; it is empty when either dimension is 0.
fn naive_rowmins(
    alloc: std.mem.Allocator,
    width: usize,
    height: usize,
    value: fn (usize, usize) i32,
) ![]usize {
    if (width == 0 or height == 0) return try alloc.alloc(usize, 0);

    const minima = try alloc.alloc(usize, height);
    for (minima, 0..) |*slot, row| {
        var best_col: usize = 0;
        var best = value(row, 0);
        for (1..width) |col| {
            const candidate = value(row, col);
            // Strict '<' keeps the earliest column on ties.
            if (candidate < best) {
                best = candidate;
                best_col = col;
            }
        }
        slot.* = best_col;
    }
    return minima;
}
// Parameters of one randomly generated Monge matrix. Its entries are the sum
// of building blocks that each preserve the Monge property:
//   (a) a row-constant term        A[i,j] += r[i]
//   (b) a column-constant term     A[i,j] += c[j]
//   (c) a weighted upper-right block of ones
//   (d) a weighted lower-left block of ones
//   (e) scaling by a positive multiplier (mult >= 1)
//   (f) summing the chosen terms
//   (g) optional transposition (i and j swapped at evaluation time)
// Entries are computed on demand; the full matrix is never materialized.
const MongeCase = struct {
    /// Width of the final matrix handed to the algorithms under test.
    width: usize,
    /// Height of the final matrix handed to the algorithms under test.
    height: usize,
    /// Width of the base expression (equals `height` when `transposed`).
    base_width: usize,
    /// Height of the base expression (equals `width` when `transposed`).
    base_height: usize,
    /// Per-row offsets, `base_height` long; an empty slice disables the term.
    row_vals: []i32,
    /// Per-column offsets, `base_width` long; an empty slice disables the term.
    col_vals: []i32,
    /// Enables the upper-right block: base rows 0..ur_top-1 and base columns
    /// base_width-ur_right..base_width-1 each receive `ur_weight`.
    has_ur: bool,
    ur_top: usize,
    ur_right: usize,
    ur_weight: i32,
    /// Enables the lower-left block: base rows base_height-ll_bottom..base_height-1
    /// and base columns 0..ll_left-1 each receive `ll_weight`.
    has_ll: bool,
    ll_bottom: usize,
    ll_left: usize,
    ll_weight: i32,
    /// Positive factor applied to the summed entry.
    mult: i32,
    /// When set, the final matrix is the transpose of the base expression.
    transposed: bool,
};
// Case currently evaluated by `monge_value`. The value callback has the bare
// signature `fn (usize, usize) i32` and cannot take a context argument, so the
// test publishes the active case here before running the algorithms.
var g_case: ?*const MongeCase = null;

// Entry (i, j) of the active case, in final (possibly transposed) coordinates.
// Panics if no case is active.
fn monge_value(i: usize, j: usize) i32 {
    const c = g_case orelse std.debug.panic("monge_value: g_case not set", .{});

    // Map final coordinates back onto the untransposed base expression.
    var row = i;
    var col = j;
    if (c.transposed) std.mem.swap(usize, &row, &col);

    const row_term: i32 = if (c.row_vals.len == 0) 0 else c.row_vals[row];
    const col_term: i32 = if (c.col_vals.len == 0) 0 else c.col_vals[col];

    // `and` short-circuits, so the subtractions only run for enabled blocks.
    const in_ur = c.has_ur and row < c.ur_top and col >= c.base_width - c.ur_right;
    const in_ll = c.has_ll and row >= c.base_height - c.ll_bottom and col < c.ll_left;
    const ur_term: i32 = if (in_ur) c.ur_weight else 0;
    const ll_term: i32 = if (in_ll) c.ll_weight else 0;

    return (row_term + col_term + ur_term + ll_term) * c.mult;
}
// Helpers for RNG

// Uniform random integer in the closed range [lo, hi_inclusive].
// Uses `Random.intRangeAtMost`, which is unbiased (unlike reducing a raw draw
// with `% span`) and never computes `hi - lo + 1`, a value that overflows i32
// for wide ranges. Requires lo <= hi_inclusive.
fn rnd_between(r: *std.Random, lo: i32, hi_inclusive: i32) i32 {
    return r.intRangeAtMost(i32, lo, hi_inclusive);
}
// Draw a single random bit from `r` and report whether it is set.
fn rnd_bool(r: *std.Random) bool {
    const bit: u1 = r.int(u1);
    return bit != 0;
}
// Cross-check SMAWK against brute force on random Monge matrices, and check
// that the row minima are nondecreasing.
test "interpolate randomized monge (fixed-seed)" {
    const allocator = std.testing.allocator;
    var prng = std.Random.DefaultPrng.init(0xD00DFEEDCAFEBABE);
    var rng = prng.random();
    const trials: usize = 200;
    var t: usize = 0;
    while (t < trials) : (t += 1) {
        // Final dims
        const width: usize = @intCast(rnd_between(&rng, 1, 25));
        const height: usize = @intCast(rnd_between(&rng, 1, 25));
        // Randomly decide if we apply transposition to the base expression
        const transposed = rnd_bool(&rng);
        const base_width: usize = if (transposed) height else width;
        const base_height: usize = if (transposed) width else height;

        // Decide to include row-const and/or col-const components
        const use_row = rnd_bool(&rng);
        const use_col = rnd_bool(&rng);
        // Every `defer` below sits at loop-body scope, so buffers are released
        // at the end of each iteration, including when an expectation fails.
        const row_vals: []i32 = if (use_row) try allocator.alloc(i32, base_height) else &[_]i32{};
        defer if (use_row) allocator.free(row_vals);
        // small values to avoid overflow
        for (row_vals) |*v| v.* = rnd_between(&rng, -3, 7);
        const col_vals: []i32 = if (use_col) try allocator.alloc(i32, base_width) else &[_]i32{};
        defer if (use_col) allocator.free(col_vals);
        for (col_vals) |*v| v.* = rnd_between(&rng, -3, 7);

        // Upper-right rectangle
        const has_ur = rnd_bool(&rng);
        var ur_top: usize = 0;
        var ur_right: usize = 0;
        var ur_weight: i32 = 0;
        if (has_ur) {
            ur_top = @intCast(rnd_between(&rng, 1, @intCast(base_height)));
            ur_right = @intCast(rnd_between(&rng, 1, @intCast(base_width)));
            ur_weight = rnd_between(&rng, 0, 5); // nonnegative
        }
        // Lower-left rectangle
        const has_ll = rnd_bool(&rng);
        var ll_bottom: usize = 0;
        var ll_left: usize = 0;
        var ll_weight: i32 = 0;
        if (has_ll) {
            ll_bottom = @intCast(rnd_between(&rng, 1, @intCast(base_height)));
            ll_left = @intCast(rnd_between(&rng, 1, @intCast(base_width)));
            ll_weight = rnd_between(&rng, 0, 5); // nonnegative
        }
        // positive multiple
        const mult: i32 = rnd_between(&rng, 1, 5);

        const case = MongeCase{
            .width = width,
            .height = height,
            .base_width = base_width,
            .base_height = base_height,
            .row_vals = row_vals,
            .col_vals = col_vals,
            .has_ur = has_ur,
            .ur_top = ur_top,
            .ur_right = ur_right,
            .ur_weight = ur_weight,
            .has_ll = has_ll,
            .ll_bottom = ll_bottom,
            .ll_left = ll_left,
            .ll_weight = ll_weight,
            .mult = mult,
            .transposed = transposed,
        };
        // Publish the case for monge_value; the defer clears it before the
        // row/col buffers it points into are freed (defers run LIFO).
        g_case = &case;
        defer g_case = null;

        // Run both algorithms on this case
        const minima_interp = try interpolate(allocator, width, height, monge_value);
        defer allocator.free(minima_interp);
        const minima_naive = try naive_rowmins(allocator, width, height, monge_value);
        defer allocator.free(minima_naive);

        // Check equality
        try std.testing.expectEqualSlices(usize, minima_naive, minima_interp);
        // Check monotonicity of row minima (nondecreasing)
        for (1..minima_interp.len) |r| {
            try std.testing.expect(minima_interp[r - 1] <= minima_interp[r]);
        }
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment