RL Graph Environment: Learning from many instances

[1]:
from job_shop_lib.reinforcement_learning import (
    MultiJobShopGraphEnv,
    ObservationSpaceKey,
    ObservationDict,
)
from job_shop_lib.dispatching import DispatcherObserverConfig
from job_shop_lib.dispatching.feature_observers import (
    FeatureObserverType,
    FeatureType,
)
from job_shop_lib.graphs import build_agent_task_graph
from job_shop_lib.generation import GeneralInstanceGenerator
[2]:
# Instances are drawn at random on every reset: the job and machine counts
# are sampled from these (min, max) ranges.
# NOTE(review): the printed results further down include an instance with
# 4 jobs and 5 machines even though allow_less_jobs_than_machines=False —
# confirm what this flag actually guarantees in the library.
generator = GeneralInstanceGenerator(
    num_jobs=(3, 6),
    num_machines=(3, 5),
    allow_less_jobs_than_machines=False,
)

# Single feature observer: the IS_READY flag, computed per job.
feature_observer_configs = [
    DispatcherObserverConfig(
        FeatureObserverType.IS_READY,
        kwargs={"feature_types": [FeatureType.JOBS]},
    ),
]

env = MultiJobShopGraphEnv(
    instance_generator=generator,
    graph_initializer=build_agent_task_graph,
    feature_observer_configs=feature_observer_configs,
    render_config={"video_config": {"fps": 4}},
    render_mode="human",  # Try "save_video"
)
[3]:
import random

random.seed(100)


def random_action(observation: ObservationDict) -> tuple[int, int]:
    """Pick a uniformly random ready job and return it as an env action.

    Args:
        observation: Observation dict produced by the environment. Its
            ``ObservationSpaceKey.JOBS`` entry holds the per-job ``IS_READY``
            feature configured above (one row per job slot). A value of
            ``1.0`` marks a ready job; other values (the reset output shows
            ``-1.0`` in trailing rows, presumably padding — TODO confirm)
            are skipped.

    Returns:
        ``(job_id, machine_id)`` where ``job_id`` is the row index of a ready
        job and ``machine_id`` is ``-1``.

    Raises:
        IndexError: Propagated from ``random.choice`` if no job is ready.
    """
    # The feature was configured with FeatureType.JOBS, so the enumerated
    # index is a job index, not an operation id.
    is_ready_per_job = observation[ObservationSpaceKey.JOBS.value].ravel()
    ready_jobs = [
        job_id
        for job_id, is_ready in enumerate(is_ready_per_job)
        if is_ready == 1.0
    ]

    job_id = random.choice(ready_jobs)
    # We can use -1 if each operation can only be scheduled in one machine.
    machine_id = -1
    return (job_id, machine_id)
[4]:
from IPython.display import clear_output

n_episodes = 3
instances = []
total_rewards = []
for _ in range(n_episodes):
    obs, _ = env.reset()
    done = False
    while not done:
        action = random_action(obs)
        obs, _reward, terminated, truncated, *_ = env.step(action)
        # Under the Gymnasium API an episode ends on either termination or
        # truncation; checking only `terminated` would keep stepping a
        # truncated episode.
        done = terminated or truncated
        if env.render_mode == "human":
            env.render()
            # wait=True clears the previous frame only once the next one is
            # ready to be displayed.
            clear_output(wait=True)
    # File-based render modes are rendered once, after the episode finishes.
    if env.render_mode in ("save_video", "save_gif"):
        env.render()

    instances.append(env.dispatcher.instance)
    # Episode return: sum of the rewards recorded by the reward function.
    total_rewards.append(sum(env.reward_function.rewards))
../_images/examples_10-MultiJobShopGraphEnv_4_0.png
[7]:
# One paragraph per episode: the generated instance, its return, a blank line.
for instance, total_reward in zip(instances, total_rewards):
    print(f"Instance: {instance}\nTotal reward: {total_reward}\n")
Instance: JobShopInstance(name=classic_generated_instance_2, num_jobs=4, num_machines=5)
Total reward: -495

Instance: JobShopInstance(name=classic_generated_instance_3, num_jobs=6, num_machines=5)
Total reward: -666

Instance: JobShopInstance(name=classic_generated_instance_4, num_jobs=4, num_machines=4)
Total reward: -312

[8]:
# Start a fresh episode and keep its observation in `obs`, so the next cell
# inspects this observation instead of whatever `obs` was left over from the
# episode loop.
obs, info = env.reset()
obs, info
[8]:
({'removed_nodes': array([False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False,
         False, False,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True]),
  'edge_index': array([[ 0,  0,  0,  0,  0,  1,  1,  1,  1,  1,  2,  2,  2,  2,  2,  3,
           3,  3,  3,  3,  4,  4,  4,  4,  4,  5,  5,  5,  5,  5,  6,  6,
           6,  6,  6,  7,  7,  7,  7,  7,  8,  8,  8,  8,  8,  9,  9,  9,
           9,  9, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11, 12, 12, 12, 12,
          12, 13, 13, 13, 13, 13, 14, 14, 14, 14, 14, 15, 15, 15, 15, 15,
          15, 15, 16, 16, 16, 16, 16, 16, 16, 17, 17, 17, 17, 17, 17, 17,
          18, 18, 18, 18, 18, 18, 18, 19, 19, 19, 19, 19, 19, 19, -1, -1,
          -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
          -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
          -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
          -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
          -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
          -1, -1, -1, -1, -1, -1, -1, -1],
         [16,  1,  2,  3,  4, 17,  0,  2,  3,  4, 18,  0,  1,  3,  4, 19,
           0,  1,  2,  4, 15,  0,  1,  2,  3, 15,  6,  7,  8,  9, 19,  5,
           7,  8,  9, 16,  5,  6,  8,  9, 17,  5,  6,  7,  9, 18,  5,  6,
           7,  8, 15, 11, 12, 13, 14, 16, 10, 12, 13, 14, 19, 10, 11, 13,
          14, 17, 10, 11, 12, 14, 18, 10, 11, 12, 13,  4,  5, 10, 16, 17,
          18, 19,  0,  7, 11, 15, 17, 18, 19,  1,  8, 13, 15, 16, 18, 19,
           2,  9, 14, 15, 16, 17, 19,  3,  6, 12, 15, 16, 17, 18, -1, -1,
          -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
          -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
          -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
          -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
          -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
          -1, -1, -1, -1, -1, -1, -1, -1]], dtype=int32),
  'jobs': array([[ 1.],
         [ 1.],
         [ 1.],
         [-1.],
         [-1.],
         [-1.]], dtype=float32)},
 {})
[9]:
# Each observation component is a NumPy array; list its name and shape
# (compare against env.observation_space).
for key, array in obs.items():
    print(f"{key} {array.shape}")
removed_nodes (35,)
edge_index (2, 200)
jobs (6, 1)
[10]:
# One fixed-size observation space is used for every generated instance.
# Smaller instances appear to be padded (e.g. -1 entries in `edge_index` and
# True entries in `removed_nodes` in the reset output) — confirm in the docs.
env.observation_space
[10]:
Dict('edge_index': MultiDiscrete([[36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36
  36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36
  36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36
  36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36
  36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36
  36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36
  36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36
  36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36
  36 36 36 36 36 36 36 36]
 [36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36
  36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36
  36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36
  36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36
  36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36
  36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36
  36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36
  36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36
  36 36 36 36 36 36 36 36]], start=[[-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
  -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
  -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
  -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
  -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
  -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
  -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
  -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
  -1 -1 -1 -1 -1 -1 -1 -1]
 [-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
  -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
  -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
  -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
  -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
  -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
  -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
  -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
  -1 -1 -1 -1 -1 -1 -1 -1]]), 'jobs': Box(-inf, inf, (6, 1), float32), 'removed_nodes': MultiBinary(35))
[ ]: