RL Graph Environment: Learning from many instances
[1]:
from job_shop_lib.reinforcement_learning import (
MultiJobShopGraphEnv,
ObservationSpaceKey,
ObservationDict,
)
from job_shop_lib.dispatching import DispatcherObserverConfig
from job_shop_lib.dispatching.feature_observers import (
FeatureObserverType,
FeatureType,
)
from job_shop_lib.graphs import build_agent_task_graph
from job_shop_lib.generation import GeneralInstanceGenerator
[2]:
# Random instance generator. The tuples are presumably (min, max) ranges for
# the sampled instance size -- TODO confirm against the generator's docs.
generator = GeneralInstanceGenerator(
    num_machines=(3, 5),
    num_jobs=(3, 6),
    allow_less_jobs_than_machines=False,
)

# A single IS_READY feature observer, restricted to job-level features.
feature_observer_configs = [
    DispatcherObserverConfig(
        FeatureObserverType.IS_READY,
        kwargs={"feature_types": [FeatureType.JOBS]},
    ),
]

# Multi-instance environment: instances come from `generator` and graphs are
# built with `build_agent_task_graph`.
env = MultiJobShopGraphEnv(
    instance_generator=generator,
    graph_initializer=build_agent_task_graph,
    feature_observer_configs=feature_observer_configs,
    render_config={"video_config": {"fps": 4}},
    render_mode="human",  # Try "save_video"
)
[3]:
import random

random.seed(100)


def random_action(observation: ObservationDict) -> tuple[int, int]:
    """Choose uniformly at random among the entries flagged as ready.

    The array stored under ``ObservationSpaceKey.JOBS.value`` is flattened
    and every position whose value equals ``1.0`` is treated as a candidate.

    Args:
        observation: Observation dictionary returned by the environment.

    Returns:
        ``(operation_id, machine_id)``: the index of the chosen ready entry
        and a machine id that is always ``-1``.

    Raises:
        IndexError: Raised by ``random.choice`` when no entry is ready.
    """
    ready_flags = observation[ObservationSpaceKey.JOBS.value].ravel()
    candidates = [
        index for index, flag in enumerate(ready_flags) if flag == 1.0
    ]
    # -1 can be used when each operation can only be scheduled on one
    # machine.
    return (random.choice(candidates), -1)
[4]:
from IPython.display import clear_output
n_episodes = 3
instances = []  # `env.dispatcher.instance` collected after each episode
total_rewards = []  # sum of `env.reward_function.rewards` per episode

for _ in range(n_episodes):
    obs, _ = env.reset()
    done = False
    while not done:
        action = random_action(obs)
        obs, reward, done, *_ = env.step(action)
        if env.render_mode == "human":
            # Redraw after every step, replacing the previous cell output.
            env.render()
            clear_output(wait=True)

    # The file-producing modes are rendered once, after the episode ends.
    if env.render_mode in ("save_video", "save_gif"):
        env.render()
    instances.append(env.dispatcher.instance)
    total_rewards.append(sum(env.reward_function.rewards))
[7]:
# One report per episode: the instance, its total reward, then a blank line.
for instance, total_reward in zip(instances, total_rewards):
    print(
        f"Instance: {instance}",
        f"Total reward: {total_reward}",
        "",
        sep="\n",
    )
Instance: JobShopInstance(name=classic_generated_instance_2, num_jobs=4, num_machines=5)
Total reward: -495
Instance: JobShopInstance(name=classic_generated_instance_3, num_jobs=6, num_machines=5)
Total reward: -666
Instance: JobShopInstance(name=classic_generated_instance_4, num_jobs=4, num_machines=4)
Total reward: -312
[8]:
# Reset the environment. As the last expression of the cell, the returned
# tuple (observation dict, empty dict) is displayed as the cell output.
env.reset()
[8]:
({'removed_nodes': array([False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True]),
'edge_index': array([[ 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3,
3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 6, 6,
6, 6, 6, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 9, 9, 9,
9, 9, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11, 12, 12, 12, 12,
12, 13, 13, 13, 13, 13, 14, 14, 14, 14, 14, 15, 15, 15, 15, 15,
15, 15, 16, 16, 16, 16, 16, 16, 16, 17, 17, 17, 17, 17, 17, 17,
18, 18, 18, 18, 18, 18, 18, 19, 19, 19, 19, 19, 19, 19, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1],
[16, 1, 2, 3, 4, 17, 0, 2, 3, 4, 18, 0, 1, 3, 4, 19,
0, 1, 2, 4, 15, 0, 1, 2, 3, 15, 6, 7, 8, 9, 19, 5,
7, 8, 9, 16, 5, 6, 8, 9, 17, 5, 6, 7, 9, 18, 5, 6,
7, 8, 15, 11, 12, 13, 14, 16, 10, 12, 13, 14, 19, 10, 11, 13,
14, 17, 10, 11, 12, 14, 18, 10, 11, 12, 13, 4, 5, 10, 16, 17,
18, 19, 0, 7, 11, 15, 17, 18, 19, 1, 8, 13, 15, 16, 18, 19,
2, 9, 14, 15, 16, 17, 19, 3, 6, 12, 15, 16, 17, 18, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1]], dtype=int32),
'jobs': array([[ 1.],
[ 1.],
[ 1.],
[-1.],
[-1.],
[-1.]], dtype=float32)},
{})
[9]:
# `obs` still holds the last observation from the episode loop; the result of
# the preceding `env.reset()` call was not assigned to it.
for name, values in obs.items():
    print(name, values.shape)
removed_nodes (35,)
edge_index (2, 200)
jobs (6, 1)
[10]:
# Display the environment's observation space (shown as the cell output).
env.observation_space
[10]:
Dict('edge_index': MultiDiscrete([[36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36
36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36
36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36
36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36
36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36
36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36
36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36
36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36
36 36 36 36 36 36 36 36]
[36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36
36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36
36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36
36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36
36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36
36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36
36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36
36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36 36
36 36 36 36 36 36 36 36]], start=[[-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
-1 -1 -1 -1 -1 -1 -1 -1]
[-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
-1 -1 -1 -1 -1 -1 -1 -1]]), 'jobs': Box(-inf, inf, (6, 1), float32), 'removed_nodes': MultiBinary(35))
[ ]: