InferOpt.jl:
Combinatorial Optimization-enhanced Machine Learning

Léo Baty

École des Ponts, CERMICS

2024-07-12

InferOpt.jl

  • State-of-the-art tools to incorporate combinatorial optimization algorithms in machine learning pipelines
  • Compatible with Julia's ML and automatic differentiation (AD) ecosystem through ChainRules
  • Part of the new JuliaDecisionFocusedLearning GitHub organization

Why?

  • Increase the expressivity of machine learning models by giving them combinatorial outputs
  • Leverage algorithms for “easy” problems to solve harder ones

Difficulty: combinatorial algorithms are piecewise constant functions of their inputs \(\implies\) their gradients are zero almost everywhere, hence uninformative

  • InferOpt provides differentiable layers and loss functions to overcome this issue.
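The difficulty can be seen on a toy problem. The snippet below is a minimal sketch in plain Julia, not part of InferOpt: `onehot_argmax` and `finite_difference_jacobian` are hypothetical helpers introduced only for illustration. Small changes of the scores never change the selected index, so the finite-difference Jacobian vanishes.

```julia
# Toy "combinatorial" layer: one-hot indicator of the largest score.
function onehot_argmax(θ::AbstractVector{<:Real})
    y = zeros(Float64, length(θ))
    y[argmax(θ)] = 1.0
    return y
end

# Central finite-difference Jacobian, built column by column.
function finite_difference_jacobian(f, θ::AbstractVector{<:Real}; h=1e-4)
    n = length(θ)
    J = zeros(Float64, length(f(θ)), n)
    for j in 1:n
        e = zeros(Float64, n)
        e[j] = h
        J[:, j] = (f(θ .+ e) .- f(θ .- e)) ./ (2h)
    end
    return J
end

θ_toy = [0.3, 1.2, 0.7]
J_toy = finite_difference_jacobian(onehot_argmax, θ_toy)
println(all(iszero, J_toy))  # prints "true": the output is locally constant
```

The output only jumps when two scores cross, and at those points the function is not differentiable at all.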

Path finding on Warcraft maps

  • Input: map image
  • Goal: find the shortest path from top left to bottom right
  • True cell costs are unknown

Dataset: Set of (image, path) pairs to imitate

Retrieving the dataset

using InferOptBenchmarks.Warcraft
b = WarcraftBenchmark();

Download and format the data:

dataset = generate_dataset(b, 50)
train_dataset, test_dataset = dataset[1:45], dataset[46:50]
x, y_true, θ_true = test_dataset[1]
plot_data(x, y_true, θ_true)

The neural network

First three layers of a ResNet

model = generate_statistical_model(b)
Chain(
  Conv((7, 7), 3 => 64, pad=3, stride=2, bias=false),  # 9_408 parameters
  BatchNorm(64, relu),                  # 128 parameters, plus 128
  MaxPool((3, 3), pad=1, stride=2),
  Parallel(
    addact(NNlib.relu, ...),
    identity,
    Chain(
      Conv((3, 3), 64 => 64, pad=1, bias=false),  # 36_864 parameters
      BatchNorm(64),                    # 128 parameters, plus 128
      NNlib.relu,
      Conv((3, 3), 64 => 64, pad=1, bias=false),  # 36_864 parameters
      BatchNorm(64),                    # 128 parameters, plus 128
    ),
  ),
  AdaptiveMaxPool((12, 12)),
  InferOptBenchmarks.Warcraft.average_tensor,
  InferOptBenchmarks.Warcraft.neg_tensor,
  InferOptBenchmarks.Warcraft.squeeze_last_dims,
)         # Total: 9 trainable arrays, 83_520 parameters,
          # plus 6 non-trainable, 384 parameters, summarysize 330.312 KiB.

Predicted costs

using Plots
θ = model(x)

kw = (; framestyle=:none, yflip=true, aspect_ratio=:equal, legend=false, size=(300, 300))
clim=(minimum(θ_true), maximum(θ_true))
heatmap(-θ; kw..., clim)

Combinatorial algorithm

We use Dijkstra's algorithm, wrapping the implementation from Graphs.jl

maximizer = generate_maximizer(b)
dijkstra_maximizer (generic function with 1 method)
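To make the role of this layer concrete, here is a self-contained sketch of a grid shortest-path oracle in plain Julia. It is not the InferOptBenchmarks implementation: `grid_shortest_path` is a hypothetical name, and the sketch assumes moves to the 8 neighbouring cells, a cost paid when entering a cell, and non-negative costs passed directly (the maximizer in the slides takes negated costs instead).

```julia
# Shortest path from the top-left to the bottom-right cell of a cost grid.
# Returns a 0/1 matrix marking the cells on the path.
function grid_shortest_path(costs::AbstractMatrix{<:Real})
    h, w = size(costs)
    dist = fill(Inf, h, w)
    parent = fill(CartesianIndex(0, 0), h, w)
    visited = falses(h, w)
    source, target = CartesianIndex(1, 1), CartesianIndex(h, w)
    dist[source] = costs[source]
    while true
        # Naive selection of the closest unvisited cell (fine for small grids).
        best, u = Inf, source
        for c in CartesianIndices(costs)
            if !visited[c] && dist[c] < best
                best, u = dist[c], c
            end
        end
        (isinf(best) || u == target) && break
        visited[u] = true
        # Relax the (up to) 8 neighbours of u.
        for di in -1:1, dj in -1:1
            (di == 0 && dj == 0) && continue
            v = u + CartesianIndex(di, dj)
            checkbounds(Bool, costs, v) || continue
            if dist[u] + costs[v] < dist[v]
                dist[v] = dist[u] + costs[v]
                parent[v] = u
            end
        end
    end
    # Walk back from the target to the source to mark the path.
    path = zeros(Int, h, w)
    c = target
    while c != source
        path[c] = 1
        c = parent[c]
    end
    path[source] = 1
    return path
end

toy_costs = [1 9 9; 9 1 9; 9 9 1]
println(grid_shortest_path(toy_costs) == [1 0 0; 0 1 0; 0 0 1])  # prints "true"
```

The output is a binary matrix, like the paths plotted in the slides, so it changes only by discrete jumps as the costs vary.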

Output of untrained pipeline

heatmap(maximizer(θ); kw...)

Path we want to output

heatmap(maximizer(-θ_true); kw...)

Computing derivatives (Zygote)

Either fails…

using Zygote

Zygote.jacobian(maximizer, θ)
ErrorException: Mutating arrays is not supported -- called setindex!(Matrix{Int64}, ...)
This error occurs when you ask Zygote to differentiate operations that change
the elements of arrays in place (e.g. setting values with x .= ...)

Possible fixes:
- avoid mutating operations (preferred)
- or read the documentation and solutions for this error
  https://fluxml.ai/Zygote.jl/latest/limitations

Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] _throw_mutation_error(f::Function, args::Matrix{Int64})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/lib/array.jl:70
  [3] (::Zygote.var"#539#540"{Matrix{Int64}})(::Nothing)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/lib/array.jl:82
  [4] (::Zygote.var"#2623#back#541"{Zygote.var"#539#540"{Matrix{Int64}}})(Δ::Nothing)
    @ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
  [5] path_to_matrix
    @ ~/.julia/packages/InferOptBenchmarks/zB6U9/src/Warcraft/grid_graph.jl:121 [inlined]
  [6] (::Zygote.Pullback{Tuple{typeof(InferOptBenchmarks.Warcraft.path_to_matrix), SimpleWeightedGraphs.SimpleWeightedDiGraph{Int64, Float32}, Vector{Int64}}, Any})(Δ::Matrix{Int64})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
  [7] #dijkstra_maximizer#12
    @ ~/.julia/packages/InferOptBenchmarks/zB6U9/src/Warcraft/grid_graph.jl:289 [inlined]
  [8] (::Zygote.Pullback{Tuple{InferOptBenchmarks.Warcraft.var"##dijkstra_maximizer#12", @Kwargs{}, typeof(InferOptBenchmarks.Warcraft.dijkstra_maximizer), Matrix{Float32}}, …})(Δ::Matrix{Int64})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
  [9] dijkstra_maximizer
    @ ~/.julia/packages/InferOptBenchmarks/zB6U9/src/Warcraft/grid_graph.jl:285 [inlined]
 [10] #291
    @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:206 [inlined]
 [11] (::Zygote.var"#2169#back#293"{…})(Δ::Matrix{Int64})
    @ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
 [12] call_composed
    @ ./operators.jl:1045 [inlined]
 [13] (::Zygote.Pullback{Tuple{typeof(Base.call_composed), Tuple{typeof(InferOptBenchmarks.Warcraft.dijkstra_maximizer)}, Tuple{Matrix{Float32}}, @Kwargs{}}, Any})(Δ::Matrix{Int64})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [14] call_composed
    @ ./operators.jl:1044 [inlined]
 [15] #_#103
    @ ./operators.jl:1041 [inlined]
 [16] (::Zygote.Pullback{Tuple{Base.var"##_#103", @Kwargs{}, ComposedFunction{typeof(Zygote._jvec), typeof(InferOptBenchmarks.Warcraft.dijkstra_maximizer)}, Matrix{Float32}}, …})(Δ::Vector{Int64})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [17] #291
    @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:206 [inlined]
 [18] #2169#back
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72 [inlined]
 [19] ComposedFunction
    @ ./operators.jl:1041 [inlined]
 [20] (::Zygote.Pullback{Tuple{ComposedFunction{typeof(Zygote._jvec), typeof(InferOptBenchmarks.Warcraft.dijkstra_maximizer)}, Matrix{Float32}}, …})(Δ::Vector{Int64})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [21] (::Zygote.var"#75#76"{…})(Δ::Vector{Int64})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:91
 [22] withjacobian(f::Function, args::Matrix{Float32})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/lib/grad.jl:150
 [23] jacobian(f::Function, args::Matrix{Float32})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/lib/grad.jl:128
 [24] top-level scope
    @ ~/work/JuliaCon2024-InferOpt/JuliaCon2024-InferOpt/main.qmd:145

Computing derivatives (ForwardDiff)

… or is zero almost everywhere

using ForwardDiff

g = ForwardDiff.jacobian(maximizer, θ)
heatmap(g; kw...)
any(g .!= 0.0)
false

Regularizing the maximizer

using InferOpt

perturbed_maximizer = PerturbedMultiplicative(maximizer; ε=0.2, nb_samples=100)
yp = perturbed_maximizer(-θ_true)
12×12 Matrix{Float64}:
 1.0   0.0   0.0   0.0  0.0   0.0   0.0   0.0   0.0  0.0  0.0   0.0
 0.48  0.52  0.0   0.0  0.0   0.0   0.0   0.0   0.0  0.0  0.0   0.0
 0.38  0.47  0.15  0.0  0.0   0.0   0.0   0.0   0.0  0.0  0.0   0.0
 0.53  0.43  0.04  0.0  0.0   0.0   0.0   0.0   0.0  0.0  0.0   0.0
 0.01  0.95  0.04  0.0  0.0   0.0   0.0   0.0   0.0  0.0  0.0   0.0
 0.0   0.01  0.99  0.0  0.0   0.0   0.0   0.0   0.0  0.0  0.0   0.0
 0.0   0.0   0.01  1.0  0.0   0.0   0.0   0.0   0.0  0.0  0.0   0.0
 0.0   0.0   0.0   0.0  1.0   0.0   0.0   0.0   0.0  0.0  0.0   0.0
 0.0   0.0   0.0   0.0  0.02  0.98  0.0   0.0   0.0  0.0  0.0   0.0
 0.0   0.0   0.0   0.0  0.0   0.02  0.98  0.0   0.0  0.0  0.0   0.0
 0.0   0.0   0.0   0.0  0.0   0.0   0.02  0.98  0.0  0.0  0.44  0.0
 0.0   0.0   0.0   0.0  0.0   0.0   0.0   0.02  1.0  1.0  0.56  1.0
heatmap(yp; kw..., legend=true)
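The idea behind this regularization can be sketched in a few lines of plain Julia. This is only an illustration, assuming the multiplicative perturbation takes the form θ ⊙ exp(εZ − ε²/2) with Z standard Gaussian, and averaging the oracle outputs by Monte Carlo. `perturbed_multiplicative_sketch` and `onehot_argmax` are hypothetical names, not InferOpt functions.

```julia
using Random

# Toy oracle: one-hot indicator of the largest score.
function onehot_argmax(θ::AbstractVector{<:Real})
    y = zeros(Float64, length(θ))
    y[argmax(θ)] = 1.0
    return y
end

# Monte Carlo estimate of E[oracle(θ ⊙ exp(εZ − ε²/2))] with Z ~ N(0, I).
function perturbed_multiplicative_sketch(oracle, θ; ε=0.2, nb_samples=100,
                                         rng=Random.default_rng())
    acc = zeros(Float64, length(oracle(θ)))
    for _ in 1:nb_samples
        Z = randn(rng, length(θ))
        acc .+= oracle(θ .* exp.(ε .* Z .- ε^2 / 2))
    end
    return acc ./ nb_samples
end

rng = MersenneTwister(0)
θ_toy = [1.0, 1.1, 0.2]
yp_toy = perturbed_multiplicative_sketch(onehot_argmax, θ_toy;
                                         ε=0.2, nb_samples=1000, rng=rng)
println(yp_toy)  # fractional weights shared between the two close scores
```

The average of many discrete solutions varies smoothly with θ, which is what makes the fractional matrix above possible. A multiplicative perturbation also preserves the sign of each entry of θ, which is convenient when the oracle needs costs of a fixed sign, as Dijkstra's algorithm does.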

It is now differentiable!

Thanks to custom backward rules

Zygote.jacobian(perturbed_maximizer, θ)[1]
144×144 Matrix{Float64}:
  0.0        0.0        0.0         …   0.0         0.0         0.0
 -0.187746   0.511398   0.301636        0.0128768   0.108458   -0.32656
 -0.094079   0.0432884  0.140508        0.00417465  0.0693126  -0.0824386
  0.0        0.0        0.0             0.0         0.0         0.0
  0.0        0.0        0.0             0.0         0.0         0.0
  0.0        0.0        0.0         …   0.0         0.0         0.0
  0.0        0.0        0.0             0.0         0.0         0.0
  0.0        0.0        0.0             0.0         0.0         0.0
  0.0        0.0        0.0             0.0         0.0         0.0
  0.0        0.0        0.0             0.0         0.0         0.0
  ⋮                                 ⋱                          
  0.0        0.0        0.0         …   0.0         0.0         0.0
  0.0        0.0        0.0             0.0         0.0         0.0
  0.0        0.0        0.0             0.0         0.0         0.0
  0.0        0.0        0.0             0.0         0.0         0.0
  0.0        0.0        0.0             0.0         0.0         0.0
  0.0        0.0        0.0         …   0.0         0.0         0.0
  0.0        0.0        0.0             0.0         0.0         0.0
  0.390815  -0.068813   0.00554018     -0.2066      0.732822    0.239241
  0.0        0.0        0.0             0.0         0.0         0.0

This allows defining a differentiable loss

For supervised learning:

loss = FenchelYoungLoss(perturbed_maximizer)
loss(θ, y_true)
2.502653261239063

Gradients are defined:

heatmap(Zygote.gradient(t -> loss(t, y_true), θ)[1]; kw..., legend=true)
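For reference, a Fenchel-Young loss built from a convex regularizer \(\Omega\) has the generic form below. This is a sketch in standard notation rather than a description of InferOpt's internals: the symbols \(\Omega\), its Fenchel conjugate \(\Omega^*\), and the regularized prediction \(\hat{y}_\Omega\) are introduced here for illustration. For perturbed maximizers, \(\Omega\) is defined implicitly by the perturbation, and the multiplicative variant used in these slides may differ in its details.

\[
\mathcal{L}_\Omega(\theta, y) = \Omega^*(\theta) + \Omega(y) - \theta^\top y,
\qquad
\nabla_\theta \mathcal{L}_\Omega(\theta, y) = \hat{y}_\Omega(\theta) - y,
\qquad
\hat{y}_\Omega(\theta) = \operatorname*{argmax}_{y'} \; \theta^\top y' - \Omega(y')
\]

The gradient is the gap between the regularized prediction and the target, so it vanishes exactly when the two agree.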

Training

Usual Flux training loop

using Flux

opt_state = Flux.setup(Adam(1e-3), model)  # Adam optimizer with learning rate 1e-3
loss_history = Float64[]
for epoch in 1:50
    # Average Fenchel-Young loss over the training set, and its gradient w.r.t. the model
    val, grads = Flux.withgradient(model) do m
        sum(loss(m(x), y) for (x, y, _) in train_dataset) / length(train_dataset)
    end
    Flux.update!(opt_state, model, grads[1])
    push!(loss_history, val)
end
plot(loss_history)
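What such a loop does can be reproduced on a toy problem without Flux. The sketch below is an assumption-laden illustration, not the pipeline above: it uses an additive Gaussian perturbation (the slides use a multiplicative one), learns the score vector directly instead of a neural network, and relies on the fact that the Fenchel-Young loss of an additively perturbed maximizer has gradient ŷ(θ) − y. All names (`perturbed_additive_sketch`, `train_toy`) are hypothetical.

```julia
using Random

# Toy oracle: one-hot indicator of the largest score.
function onehot_argmax(θ::AbstractVector{<:Real})
    y = zeros(Float64, length(θ))
    y[argmax(θ)] = 1.0
    return y
end

# Monte Carlo estimate of the additively perturbed maximizer E[oracle(θ + εZ)].
function perturbed_additive_sketch(oracle, θ; ε, nb_samples, rng)
    acc = zeros(Float64, length(θ))
    for _ in 1:nb_samples
        acc .+= oracle(θ .+ ε .* randn(rng, length(θ)))
    end
    return acc ./ nb_samples
end

# Plain gradient descent on the scores, using the gradient estimate y_hat - y_target.
function train_toy(; nb_steps=200, step_size=0.1, rng=MersenneTwister(0))
    θ = zeros(3)
    y_target = [0.0, 0.0, 1.0]
    for _ in 1:nb_steps
        y_hat = perturbed_additive_sketch(onehot_argmax, θ;
                                          ε=1.0, nb_samples=50, rng=rng)
        θ .-= step_size .* (y_hat .- y_target)
    end
    return θ
end

θ_learned = train_toy()
println(argmax(θ_learned))  # prints "3": the learned scores select the target
```

Each step raises the score of the target coordinate and lowers the others in proportion to how often they are selected under perturbation, so the update stays informative even though the oracle itself is piecewise constant.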

Prediction

Predicted costs and path

(x, y_true, θ_true) = test_dataset[1]
θ = model(x)
y = UInt8.(maximizer(θ))
plot_data(x, y, θ; θ_true)

Thank you!