Staff Scheduling

Introduction

In this example, we will explore how to model a staff scheduling problem using the DecisionAI interface. Staff scheduling is a common challenge in many organizations, where the goal is to allocate employees to shifts efficiently. Through this example, you will learn how to define the problem’s inputs, variables, objective, and constraints within the DecisionAI framework.

To try out this base model in the CLI, set your API key and then run the chat example:

export QUANTAGONIA_API_KEY=<YOUR_API_KEY>

python -m decision_ai.examples.staff_scheduling.chat_example

We start with the mathematical definition of the problem and then show how to implement it using the DecisionAI interface.

Mathematical Model Description

The staff scheduling model is designed to allocate employees to shifts over a one-week period. The goal is to ensure that all shifts are adequately staffed while minimizing the difference between the maximum and minimum number of shifts assigned to any employee.

Inputs

  • \(E\): The set of employees.

  • \(S\): The set of shifts.

  • \(A \subseteq E \times S\): The availability of employees for shifts, indicating which employees are available for which shifts.

  • \(r_j\): The required number of employees for each shift \(j \in S\).

  • \(t\): The total understaffing tolerance, that is, the maximum slack allowed when summed over all shifts.

  • \(t_j\): The allowed slack for each shift \(j \in S\), specifying the maximum understaffing permitted for each shift.

Variables

  • \(x_{ij} \in \{0, 1\}\): A binary variable for each pair \((i, j) \in A\), where \(x_{ij} = 1\) if employee \(i\) is assigned to shift \(j\), and \(0\) otherwise.

  • \(s_j \geq 0\): A slack variable representing the understaffing of shift \(j\).

  • \(d^{\min} \geq 0\): The smallest number of shifts assigned to any employee.

  • \(d^{\max} \geq 0\): The largest number of shifts assigned to any employee.

Objective

The objective is to minimize the difference between the maximum and minimum number of shifts assigned to any employee:

\[\min d^{\max} - d^{\min}\]

Constraints

  1. Shift Requirements: Each shift must meet its staffing requirement, allowing for some slack:

    \[\sum_{i : (i, j) \in A} x_{ij} + s_j = r_j \quad \forall j \in S\]
  2. Slack Limitation: The total slack across all shifts must not exceed the tolerance \(t\):

    \[\sum_{j \in S} s_j \leq t\]
  3. Slack Per Shift: The slack for each shift must not exceed the allowed slack for that shift:

    \[s_j \leq t_j \quad \forall j \in S\]
  4. Min/Max Shifts per Employee: The number of shifts assigned to each employee must lie between \(d^{\min}\) and \(d^{\max}\), which ties these two variables to the actual workloads:

    \[d^{\min} \leq \sum_{j : (i, j) \in A} x_{ij} \leq d^{\max} \quad \forall i \in E\]
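To make the four constraints and the objective concrete, the following sketch solves a tiny instance by brute-force enumeration in plain Python. It is only an illustration of the mathematical model above, not how DecisionAI or a MIP solver works internally; the data (two employees, three shifts, the tolerance values) and the function name solve_by_enumeration are made up for this example.

```python
from itertools import product

# Hypothetical toy data, for illustration only
employees = ["Ann", "Bob"]                                   # E
shifts = ["Mon", "Tue", "Wed"]                               # S
availability = [("Ann", "Mon"), ("Ann", "Tue"),
                ("Bob", "Tue"), ("Bob", "Wed")]              # A
required = {"Mon": 1, "Tue": 1, "Wed": 1}                    # r_j
total_tolerance = 1                                          # t
shift_tolerance = {"Mon": 0, "Tue": 1, "Wed": 1}             # t_j


def solve_by_enumeration():
    """Try every 0/1 assignment over A and keep the fairest feasible one."""
    best = None
    for values in product([0, 1], repeat=len(availability)):
        x = dict(zip(availability, values))

        # Constraint 1: s_j = r_j - (employees assigned to j), with s_j >= 0
        slack = {
            j: required[j] - sum(v for (_, shift), v in x.items() if shift == j)
            for j in shifts
        }
        if any(s < 0 for s in slack.values()):
            continue
        # Constraint 2: total slack is bounded by t
        if sum(slack.values()) > total_tolerance:
            continue
        # Constraint 3: slack per shift is bounded by t_j
        if any(slack[j] > shift_tolerance[j] for j in shifts):
            continue

        # Constraint 4: for a fixed x, the best d_min and d_max are the
        # smallest and largest workload, so the objective is their difference
        load = {
            i: sum(v for (emp, _), v in x.items() if emp == i) for i in employees
        }
        spread = max(load.values()) - min(load.values())
        if best is None or spread < best[0]:
            best = (spread, x, slack)
    return best


spread, x, slack = solve_by_enumeration()
print("d_max - d_min =", spread)
print("assigned:", [pair for pair, v in x.items() if v == 1])
print("slack per shift:", slack)
```

In this toy instance the fairest schedule has a spread of 0 but uses the one unit of tolerated slack: because the objective only measures fairness, leaving a shift understaffed is acceptable whenever the tolerances permit it.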

Model Implementation

In order to implement the model, we first need to define the input data, variables, and constraints.

Input Data

We start by defining the input data class.

class StaffSchedulingInput(InputData):
    employees: list[str]  # The employees to assign to shifts
    shifts: list[str]  # The shifts to assign employees to
    availability: list[tuple[str, str]]  # The availability of employees for shifts
    shift_requirements: dict[str, int]  # The required number of employees for each shift
    employee_available_for_shift: dict[str, list[str]]  # The employees available for each shift
    shift_feasible_for_employee: dict[str, list[str]]  # The shifts feasible for each employee

    @staticmethod
    def from_csvs(path_to_directory: str) -> StaffSchedulingInput:
        """Utility method to load the input data from CSV files."""
        # Load CSV files directly
        employees_df = pd.read_csv(os.path.join(path_to_directory, "employees.csv"))
        shifts_df = pd.read_csv(os.path.join(path_to_directory, "shifts.csv"))
        availability_df = pd.read_csv(os.path.join(path_to_directory, "availability.csv"))

        # Extract data from loaded CSVs
        employees = employees_df["employee"].tolist()
        shifts = shifts_df["shift"].tolist()
        availability = [tuple(item) for item in availability_df.values.tolist()]
        shift_requirements = dict(shifts_df[["shift", "required_employees"]].values.tolist())

        # Process availabilities
        employee_available_for_shift, shift_feasible_for_employee = StaffSchedulingInput.process_availabilities(
            availability
        )

        # Return a new instance of StaffSchedulingInput
        return StaffSchedulingInput(
            employees=employees,
            shifts=shifts,
            availability=availability,
            shift_requirements=shift_requirements,
            employee_available_for_shift=employee_available_for_shift,
            shift_feasible_for_employee=shift_feasible_for_employee,
        )

    @staticmethod
    def load_example() -> StaffSchedulingInput:
        """Load example data.

        Returns:
            StaffSchedulingInput: The example data
        """
        return StaffSchedulingInput.from_csvs(files(staff_scheduling).joinpath("data_tables"))

    @hide_from_ai
    @staticmethod
    def process_availabilities(
        availability: list[tuple[str, str]],
    ) -> tuple[dict[str, list[str]], dict[str, list[str]]]:
        employee_available_for_shift = defaultdict[str, list[str]](list)
        shift_feasible_for_employee = defaultdict[str, list[str]](list)
        for employee, shift in availability:
            # defaultdict creates the empty list on first access to a new key
            employee_available_for_shift[shift].append(employee)
            shift_feasible_for_employee[employee].append(shift)

        return employee_available_for_shift, shift_feasible_for_employee

The input data class inherits from decision_ai.InputData, which provides the basic structure for defining input data in DecisionAI models.
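The two lookup dictionaries are derived from the availability pairs. As a minimal standalone sketch of that step (plain Python, with made-up employee and shift names), the grouping looks like this:

```python
from collections import defaultdict

# Hypothetical (employee, shift) availability pairs
availability = [("Ann", "Mon"), ("Ann", "Tue"), ("Bob", "Tue")]

employee_available_for_shift = defaultdict(list)
shift_feasible_for_employee = defaultdict(list)
for employee, shift in availability:
    # Group the same pairs once by shift and once by employee
    employee_available_for_shift[shift].append(employee)
    shift_feasible_for_employee[employee].append(shift)

print(dict(employee_available_for_shift))
# {'Mon': ['Ann'], 'Tue': ['Ann', 'Bob']}
print(dict(shift_feasible_for_employee))
# {'Ann': ['Mon', 'Tue'], 'Bob': ['Tue']}
```

Note that a shift or employee that appears in no availability pair gets no key at all. Once the result is stored as a plain dict, looking up such a key would raise a KeyError, so this sketch assumes every shift and every employee occurs in at least one pair.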

Variables

We then define the variables as a Pydantic model. It shows several kinds of variable definitions: dictionaries keyed by shift (the positive and negative slack variables), a dictionary of dictionaries keyed by employee and then by shift (the assignment variables), and single variables (the minimum and maximum total shifts). For each variable, we define a static method init_{variable_name} that returns the initialized variable given the input data.

class StaffSchedulingVariables(PulpVariables):
    assigned_to: dict[str, dict[str, pulp.LpVariable]] = Field(
        ..., description="Mapping of employees to shifts with assignment variables"
    )
    slack_positive_shift_requirements: dict[str, pulp.LpVariable] = Field(
        ...,
        description="Positive slack variables for each shift to meet shift requirements",
    )
    slack_negative_shift_requirements: dict[str, pulp.LpVariable] = Field(
        ...,
        description="Negative slack variables for each shift to meet shift requirements",
    )
    min_total_shifts_per_employee: pulp.LpVariable = Field(
        ..., description="Minimum total shifts assigned to any employee"
    )
    max_total_shifts_per_employee: pulp.LpVariable = Field(
        ..., description="Maximum total shifts assigned to any employee"
    )

    @staticmethod
    def init_assigned_to(input_: StaffSchedulingInput) -> dict[str, dict[str, pulp.LpVariable]]:
        assigned_to: dict[str, dict[str, pulp.LpVariable]] = {}
        for employee in input_.employees:
            assigned_to[employee] = {}
            for shift in input_.shifts:
                assigned_to[employee][shift] = pulp.LpVariable(f"{employee}_assigned_to_{shift}", cat=pulp.LpBinary)
        return assigned_to

    @staticmethod
    def init_slack_positive_shift_requirements(input_: StaffSchedulingInput) -> dict[str, pulp.LpVariable]:
        slack_positive_shift_requirements = {}
        for shift in input_.shifts:
            slack_positive_shift_requirements[shift] = pulp.LpVariable(
                f"slack_positive_shift_requirements_for_shift_{shift}",
                lowBound=0,
                cat=pulp.LpContinuous,
            )
        return slack_positive_shift_requirements

    @staticmethod
    def init_slack_negative_shift_requirements(input_: StaffSchedulingInput) -> dict[str, pulp.LpVariable]:
        slack_negative_shift_requirements = {}
        for shift in input_.shifts:
            slack_negative_shift_requirements[shift] = pulp.LpVariable(
                f"slack_negative_shift_requirements_for_shift_{shift}",
                lowBound=0,
                cat=pulp.LpContinuous,
            )
        return slack_negative_shift_requirements

    @staticmethod
    def init_min_total_shifts_per_employee(input_: StaffSchedulingInput) -> pulp.LpVariable:  # noqa: ARG004
        return pulp.LpVariable(
            "min_total_shifts_for_any_employee",
            lowBound=0,
            cat=pulp.LpContinuous,
        )

    @staticmethod
    def init_max_total_shifts_per_employee(input_: StaffSchedulingInput) -> pulp.LpVariable:  # noqa: ARG004
        return pulp.LpVariable(
            "max_total_shifts_for_any_employee",
            lowBound=0,
            cat=pulp.LpContinuous,
        )

The variables class inherits from decision_ai.PulpVariables, which provides the framework for defining optimization variables in DecisionAI models.

Model Class

Finally, we define the model class. Note that we attach the variables class to the model class by assigning it to the variables_class attribute.

class StaffSchedulingModel(PulpDecisionAIModel[StaffSchedulingInput, StaffSchedulingVariables]):
    variables_class = StaffSchedulingVariables

    def __init__(self):
        super().__init__()

    def solution_to_str(self, input_: StaffSchedulingInput, solution: Solution) -> str:
        solution_display = "##### Solution\n\n"
        solution_display += f"Objective value: {solution.objective}\n\n"
        solution_display += f"Status: {solution.status}\n\n"

        if solution.status == "Infeasible":
            return solution_display

        assignments = {
            employee: {shift: solution.variables.assigned_to[employee][shift] for shift in input_.shifts}
            for employee in input_.employees
        }

        solution_display += "Assignments:\n\n"
        for employee, shifts in assignments.items():
            # Collect the shifts whose assignment value is positive
            assigned_shifts = [shift for shift, value in shifts.items() if value > 0]
            # Always terminate the line, even for employees without any shift
            solution_display += f" * {employee}: {', '.join(assigned_shifts)}\n\n"

        return solution_display

    @constraint
    def shift_requirements(input_: StaffSchedulingInput, variables: StaffSchedulingVariables) -> ConstraintGenerator:
        # Ensure the number of employees assigned to a shift fulfills the shift requirements
        for shift in input_.shifts:
            number_of_employee_assigned_to_shift = pulp.lpSum(
                [variables.assigned_to[employee][shift] for employee in input_.employee_available_for_shift[shift]],
            )
            yield (
                number_of_employee_assigned_to_shift
                + variables.slack_positive_shift_requirements[shift]
                - variables.slack_negative_shift_requirements[shift]
                == input_.shift_requirements[shift],
                f"Shift requirements {shift}",
            )

    @constraint
    def enforce_unavailabilities(
        input_: StaffSchedulingInput,
        variables: StaffSchedulingVariables,
    ) -> ConstraintGenerator:
        # Force assigned_to variables to 0 when employee is not available for shift
        for shift in input_.shifts:
            for employee in input_.employees:
                if employee not in input_.employee_available_for_shift[shift]:
                    yield (
                        variables.assigned_to[employee][shift] == 0,
                        f"Employee {employee} not available for shift {shift}",
                    )

    @constraint
    def min_max(input_: StaffSchedulingInput, variables: StaffSchedulingVariables) -> ConstraintGenerator:
        #  Min/max number of shifts per employee
        total_shifts = {}
        for employee in input_.employees:
            total_shifts[employee] = pulp.lpSum(
                [variables.assigned_to[employee][shift] for shift in input_.shift_feasible_for_employee[employee]],
            )
            yield (
                variables.min_total_shifts_per_employee <= total_shifts[employee],
                f"Minimum total shifts for employee {employee}",
            )
            yield (
                total_shifts[employee] <= variables.max_total_shifts_per_employee,
                f"Maximum total shifts for employee {employee}",
            )

    def objective_part_number_of_shifts_variance(
        self,
        variables: StaffSchedulingVariables,
    ) -> pulp.LpAffineExpression:
        # Minimize the difference between max and min shifts per employee for fairness
        return variables.max_total_shifts_per_employee - variables.min_total_shifts_per_employee

    def objective_part_slack_penalty(
        self,
        variables: StaffSchedulingVariables,
    ) -> pulp.LpAffineExpression:
        # Penalize deviations from shift requirements using slack variables
        slack_penalty_weight = 5.0
        return slack_penalty_weight * (
            pulp.lpSum(variables.slack_positive_shift_requirements.values())
            + pulp.lpSum(variables.slack_negative_shift_requirements.values())
        )

    def set_up_objective(
        self,
        input_: StaffSchedulingInput,
        prob: pulp.LpProblem,
        variables: StaffSchedulingVariables,
    ) -> pulp.LpProblem:
        # Link all objective parts that begin with objective_part
        number_of_shifts_variance = self.objective_part_number_of_shifts_variance(variables)
        slack_penalty = self.objective_part_slack_penalty(variables)

        # Sum all objective components
        total_objective = number_of_shifts_variance + slack_penalty

        prob += (total_objective, "Variance of number of shifts per employee + slack penalty")

        return prob

    @hide_from_ai
    @staticmethod
    def display_description() -> None:
        console = Console()
        markdown_description = files(staff_scheduling).joinpath("MODEL.md").read_text()
        console.print(Markdown(markdown_description))

    @hide_from_ai
    @staticmethod
    def get_description() -> str:
        return files(staff_scheduling).joinpath("MODEL.md").read_text()

The model class inherits from decision_ai.PulpDecisionAIModel, which provides the core optimization modeling capabilities. The constraints are defined using the decision_ai.constraint() decorator, which marks methods as constraint generators.
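The listing differs from the mathematical description in two visible ways: the single slack \(s_j\) is split into a positive and a negative part, and instead of bounding the slack by \(t\) and \(t_j\), the objective penalizes it with a weight of 5.0. Using \(s_j^{+}\) and \(s_j^{-}\) as our own shorthand for slack_positive_shift_requirements and slack_negative_shift_requirements, the implemented objective and shift constraint read:

\[\min \; d^{\max} - d^{\min} + 5 \sum_{j \in S} \left(s_j^{+} + s_j^{-}\right)\]

\[\sum_{i : (i, j) \in A} x_{ij} + s_j^{+} - s_j^{-} = r_j \quad \forall j \in S\]

The sketch below evaluates this objective for a given schedule in plain Python, to show how the penalty trades off against fairness. The function name objective_value and the example schedules are hypothetical; the sketch relies on the fact that, since both slack parts are penalized, an optimal solution has \(s_j^{+} + s_j^{-} = |r_j - \text{staffing of } j|\).

```python
SLACK_PENALTY_WEIGHT = 5.0  # same value as slack_penalty_weight in the model


def objective_value(assignments, required):
    """Evaluate the implemented objective for a fixed schedule.

    assignments: dict mapping employee -> list of assigned shifts
    required:    dict mapping shift -> required number of employees
    """
    # Fairness part: difference between largest and smallest workload
    loads = [len(assigned) for assigned in assignments.values()]
    spread = max(loads) - min(loads)

    # Count how many employees cover each shift
    staffed = {shift: 0 for shift in required}
    for assigned in assignments.values():
        for shift in assigned:
            staffed[shift] += 1

    # Slack part: absolute deviation from the requirement, summed over shifts
    deviation = sum(abs(required[j] - staffed[j]) for j in required)
    return spread + SLACK_PENALTY_WEIGHT * deviation


required = {"Mon": 1, "Tue": 1, "Wed": 1}
balanced_but_understaffed = {"Ann": ["Mon"], "Bob": ["Wed"]}
fully_staffed = {"Ann": ["Mon"], "Bob": ["Tue", "Wed"]}

print(objective_value(balanced_but_understaffed, required))  # 5.0
print(objective_value(fully_staffed, required))              # 1.0
```

With a weight of 5.0, leaving one shift short by one employee costs more than an imbalance of one shift between two employees, so the fully staffed schedule scores better here.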

Note

The method process_availabilities is a helper that from_csvs uses to build the two availability lookup dictionaries. We mark it with the decision_ai.hide_from_ai() decorator to exclude it from the context passed to the LLMs, as it doesn’t provide any useful information for them. The same decorator is applied to display_description and get_description in the model class.

Complete Example

# ruff: noqa: ARG002, N805
# Ignore unused arguments for required interface signatures
# Ignore that first argument must be called self (ruff does not recognize that @constraint is a @staticmethod)

from __future__ import annotations

import os
from collections import defaultdict
from importlib.resources import files
from typing import TYPE_CHECKING

import pandas as pd
import pulp
from pydantic import Field
from rich.console import Console
from rich.markdown import Markdown

from decision_ai import InputData, PulpDecisionAIModel, PulpVariables, constraint, hide_from_ai
from decision_ai.examples import staff_scheduling

if TYPE_CHECKING:
    from decision_ai import Solution
    from decision_ai.typing import ConstraintGenerator


class StaffSchedulingInput(InputData):
    employees: list[str]  # The employees to assign to shifts
    shifts: list[str]  # The shifts to assign employees to
    availability: list[tuple[str, str]]  # The availability of employees for shifts
    shift_requirements: dict[str, int]  # The required number of employees for each shift
    employee_available_for_shift: dict[str, list[str]]  # The employees available for each shift
    shift_feasible_for_employee: dict[str, list[str]]  # The shifts feasible for each employee

    @staticmethod
    def from_csvs(path_to_directory: str) -> StaffSchedulingInput:
        """Utility method to load the input data from CSV files."""
        # Load CSV files directly
        employees_df = pd.read_csv(os.path.join(path_to_directory, "employees.csv"))
        shifts_df = pd.read_csv(os.path.join(path_to_directory, "shifts.csv"))
        availability_df = pd.read_csv(os.path.join(path_to_directory, "availability.csv"))

        # Extract data from loaded CSVs
        employees = employees_df["employee"].tolist()
        shifts = shifts_df["shift"].tolist()
        availability = [tuple(item) for item in availability_df.values.tolist()]
        shift_requirements = dict(shifts_df[["shift", "required_employees"]].values.tolist())

        # Process availabilities
        employee_available_for_shift, shift_feasible_for_employee = StaffSchedulingInput.process_availabilities(
            availability
        )

        # Return a new instance of StaffSchedulingInput
        return StaffSchedulingInput(
            employees=employees,
            shifts=shifts,
            availability=availability,
            shift_requirements=shift_requirements,
            employee_available_for_shift=employee_available_for_shift,
            shift_feasible_for_employee=shift_feasible_for_employee,
        )

    @staticmethod
    def load_example() -> StaffSchedulingInput:
        """Load example data.

        Returns:
            StaffSchedulingInput: The example data
        """
        return StaffSchedulingInput.from_csvs(files(staff_scheduling).joinpath("data_tables"))

    @hide_from_ai
    @staticmethod
    def process_availabilities(
        availability: list[tuple[str, str]],
    ) -> tuple[dict[str, list[str]], dict[str, list[str]]]:
        employee_available_for_shift = defaultdict[str, list[str]](list)
        shift_feasible_for_employee = defaultdict[str, list[str]](list)
        for employee, shift in availability:
            # defaultdict creates the empty list on first access to a new key
            employee_available_for_shift[shift].append(employee)
            shift_feasible_for_employee[employee].append(shift)

        return employee_available_for_shift, shift_feasible_for_employee


class StaffSchedulingVariables(PulpVariables):
    assigned_to: dict[str, dict[str, pulp.LpVariable]] = Field(
        ..., description="Mapping of employees to shifts with assignment variables"
    )
    slack_positive_shift_requirements: dict[str, pulp.LpVariable] = Field(
        ...,
        description="Positive slack variables for each shift to meet shift requirements",
    )
    slack_negative_shift_requirements: dict[str, pulp.LpVariable] = Field(
        ...,
        description="Negative slack variables for each shift to meet shift requirements",
    )
    min_total_shifts_per_employee: pulp.LpVariable = Field(
        ..., description="Minimum total shifts assigned to any employee"
    )
    max_total_shifts_per_employee: pulp.LpVariable = Field(
        ..., description="Maximum total shifts assigned to any employee"
    )

    @staticmethod
    def init_assigned_to(input_: StaffSchedulingInput) -> dict[str, dict[str, pulp.LpVariable]]:
        assigned_to: dict[str, dict[str, pulp.LpVariable]] = {}
        for employee in input_.employees:
            assigned_to[employee] = {}
            for shift in input_.shifts:
                assigned_to[employee][shift] = pulp.LpVariable(f"{employee}_assigned_to_{shift}", cat=pulp.LpBinary)
        return assigned_to

    @staticmethod
    def init_slack_positive_shift_requirements(input_: StaffSchedulingInput) -> dict[str, pulp.LpVariable]:
        slack_positive_shift_requirements = {}
        for shift in input_.shifts:
            slack_positive_shift_requirements[shift] = pulp.LpVariable(
                f"slack_positive_shift_requirements_for_shift_{shift}",
                lowBound=0,
                cat=pulp.LpContinuous,
            )
        return slack_positive_shift_requirements

    @staticmethod
    def init_slack_negative_shift_requirements(input_: StaffSchedulingInput) -> dict[str, pulp.LpVariable]:
        slack_negative_shift_requirements = {}
        for shift in input_.shifts:
            slack_negative_shift_requirements[shift] = pulp.LpVariable(
                f"slack_negative_shift_requirements_for_shift_{shift}",
                lowBound=0,
                cat=pulp.LpContinuous,
            )
        return slack_negative_shift_requirements

    @staticmethod
    def init_min_total_shifts_per_employee(input_: StaffSchedulingInput) -> pulp.LpVariable:  # noqa: ARG004
        return pulp.LpVariable(
            "min_total_shifts_for_any_employee",
            lowBound=0,
            cat=pulp.LpContinuous,
        )

    @staticmethod
    def init_max_total_shifts_per_employee(input_: StaffSchedulingInput) -> pulp.LpVariable:  # noqa: ARG004
        return pulp.LpVariable(
            "max_total_shifts_for_any_employee",
            lowBound=0,
            cat=pulp.LpContinuous,
        )


class StaffSchedulingModel(PulpDecisionAIModel[StaffSchedulingInput, StaffSchedulingVariables]):
    variables_class = StaffSchedulingVariables

    def __init__(self):
        super().__init__()

    def solution_to_str(self, input_: StaffSchedulingInput, solution: Solution) -> str:
        solution_display = "##### Solution\n\n"
        solution_display += f"Objective value: {solution.objective}\n\n"
        solution_display += f"Status: {solution.status}\n\n"

        if solution.status == "Infeasible":
            return solution_display

        assignments = {
            employee: {shift: solution.variables.assigned_to[employee][shift] for shift in input_.shifts}
            for employee in input_.employees
        }

        solution_display += "Assignments:\n\n"
        for employee, shifts in assignments.items():
            # Collect the shifts whose assignment value is positive
            assigned_shifts = [shift for shift, value in shifts.items() if value > 0]
            # Always terminate the line, even for employees without any shift
            solution_display += f" * {employee}: {', '.join(assigned_shifts)}\n\n"

        return solution_display

    @constraint
    def shift_requirements(input_: StaffSchedulingInput, variables: StaffSchedulingVariables) -> ConstraintGenerator:
        # Ensure the number of employees assigned to a shift fulfills the shift requirements
        for shift in input_.shifts:
            number_of_employee_assigned_to_shift = pulp.lpSum(
                [variables.assigned_to[employee][shift] for employee in input_.employee_available_for_shift[shift]],
            )
            yield (
                number_of_employee_assigned_to_shift
                + variables.slack_positive_shift_requirements[shift]
                - variables.slack_negative_shift_requirements[shift]
                == input_.shift_requirements[shift],
                f"Shift requirements {shift}",
            )

    @constraint
    def enforce_unavailabilities(
        input_: StaffSchedulingInput,
        variables: StaffSchedulingVariables,
    ) -> ConstraintGenerator:
        # Force assigned_to variables to 0 when employee is not available for shift
        for shift in input_.shifts:
            for employee in input_.employees:
                if employee not in input_.employee_available_for_shift[shift]:
                    yield (
                        variables.assigned_to[employee][shift] == 0,
                        f"Employee {employee} not available for shift {shift}",
                    )

    @constraint
    def min_max(input_: StaffSchedulingInput, variables: StaffSchedulingVariables) -> ConstraintGenerator:
        #  Min/max number of shifts per employee
        total_shifts = {}
        for employee in input_.employees:
            total_shifts[employee] = pulp.lpSum(
                [variables.assigned_to[employee][shift] for shift in input_.shift_feasible_for_employee[employee]],
            )
            yield (
                variables.min_total_shifts_per_employee <= total_shifts[employee],
                f"Minimum total shifts for employee {employee}",
            )
            yield (
                total_shifts[employee] <= variables.max_total_shifts_per_employee,
                f"Maximum total shifts for employee {employee}",
            )

    def objective_part_number_of_shifts_variance(
        self,
        variables: StaffSchedulingVariables,
    ) -> pulp.LpAffineExpression:
        # Minimize the difference between max and min shifts per employee for fairness
        return variables.max_total_shifts_per_employee - variables.min_total_shifts_per_employee

    def objective_part_slack_penalty(
        self,
        variables: StaffSchedulingVariables,
    ) -> pulp.LpAffineExpression:
        # Penalize deviations from shift requirements using slack variables
        slack_penalty_weight = 5.0
        return slack_penalty_weight * (
            pulp.lpSum(variables.slack_positive_shift_requirements.values())
            + pulp.lpSum(variables.slack_negative_shift_requirements.values())
        )

    def set_up_objective(
        self,
        input_: StaffSchedulingInput,
        prob: pulp.LpProblem,
        variables: StaffSchedulingVariables,
    ) -> pulp.LpProblem:
        # Link all objective parts that begin with objective_part
        number_of_shifts_variance = self.objective_part_number_of_shifts_variance(variables)
        slack_penalty = self.objective_part_slack_penalty(variables)

        # Sum all objective components
        total_objective = number_of_shifts_variance + slack_penalty

        prob += (total_objective, "Variance of number of shifts per employee + slack penalty")

        return prob

    @hide_from_ai
    @staticmethod
    def display_description() -> None:
        console = Console()
        markdown_description = files(staff_scheduling).joinpath("MODEL.md").read_text()
        console.print(Markdown(markdown_description))

    @hide_from_ai
    @staticmethod
    def get_description() -> str:
        return files(staff_scheduling).joinpath("MODEL.md").read_text()