Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 111 additions & 12 deletions torchx/schedulers/kubernetes_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,6 @@
from torchx.util.strings import normalize_str
from torchx.workspace.docker_workspace import DockerWorkspaceMixin


if TYPE_CHECKING:
from docker import DockerClient
from kubernetes.client import ApiClient, CustomObjectsApi
Expand All @@ -159,6 +158,7 @@
)
from kubernetes.client.rest import ApiException


logger: logging.Logger = logging.getLogger(__name__)

# Kubernetes reserves a small amount of resources per host for the system. For
Expand Down Expand Up @@ -294,7 +294,14 @@ def sanitize_for_serialization(obj: object) -> object:
return api.sanitize_for_serialization(obj)


def role_to_pod(name: str, role: Role, service_account: Optional[str]) -> "V1Pod":
def role_to_pod(
name: str,
role: Role,
service_account: Optional[str],
reserved_millicpu: int = RESERVED_MILLICPU,
reserved_memmb: int = RESERVED_MEMMB,
efa_device_count: Optional[int] = None,
) -> "V1Pod":
from kubernetes.client.models import ( # noqa: F811 redefinition of unused
V1Container,
V1ContainerPort,
Expand Down Expand Up @@ -324,18 +331,29 @@ def role_to_pod(name: str, role: Role, service_account: Optional[str]) -> "V1Pod
if resource.cpu > 0:
mcpu = int(resource.cpu * 1000)
limits["cpu"] = f"{mcpu}m"
request_mcpu = max(mcpu - RESERVED_MILLICPU, 0)
request_mcpu = max(mcpu - reserved_millicpu, 0)
requests["cpu"] = f"{request_mcpu}m"
if resource.memMB > 0:
limits["memory"] = f"{int(resource.memMB)}M"
request_memMB = max(int(resource.memMB) - RESERVED_MEMMB, 0)
request_memMB = max(int(resource.memMB) - reserved_memmb, 0)
requests["memory"] = f"{request_memMB}M"
if resource.gpu > 0:
requests["nvidia.com/gpu"] = limits["nvidia.com/gpu"] = str(resource.gpu)

EFA_DEVICE = "vpc.amazonaws.com/efa"
for device_name, device_limit in resource.devices.items():
limits[device_name] = str(device_limit)

# Handle EFA device count override:
# - None (default): use whatever count is in the resource spec (already added above)
# - 0: remove EFA devices entirely
# - N > 0: set EFA device count to N (override or add)
if efa_device_count is not None:
if efa_device_count == 0:
limits.pop(EFA_DEVICE, None)
else:
limits[EFA_DEVICE] = str(efa_device_count)

resources = V1ResourceRequirements(
limits=limits,
requests=requests,
Expand Down Expand Up @@ -475,6 +493,9 @@ def app_to_resource(
queue: str,
service_account: Optional[str],
priority_class: Optional[str] = None,
reserved_millicpu: int = RESERVED_MILLICPU,
reserved_memmb: int = RESERVED_MEMMB,
efa_device_count: Optional[int] = None,
) -> Dict[str, Any]:
"""
app_to_resource creates a volcano job kubernetes resource definition from
Expand Down Expand Up @@ -507,7 +528,14 @@ def app_to_resource(
replica_role.env["TORCHX_RANK0_HOST"] = "localhost"
replica_role.env["TORCHX_IMAGE"] = replica_role.image

pod = role_to_pod(name, replica_role, service_account)
pod = role_to_pod(
name,
replica_role,
service_account,
reserved_millicpu,
reserved_memmb,
efa_device_count,
)
if k8s_metadata := role.metadata.get("kubernetes"):
if isinstance(k8s_metadata, str):
import fsspec
Expand Down Expand Up @@ -589,6 +617,9 @@ class KubernetesOpts(TypedDict, total=False):
service_account: Optional[str]
priority_class: Optional[str]
validate_spec: Optional[bool]
reserved_millicpu: Optional[int]
reserved_memmb: Optional[int]
efa_device_count: Optional[int]


class KubernetesScheduler(DockerWorkspaceMixin, Scheduler[KubernetesOpts]):
Expand Down Expand Up @@ -707,9 +738,14 @@ def _api_client(self) -> "ApiClient":
if c is None:
configuration = client.Configuration()
try:
config.load_kube_config(client_configuration=configuration)
except config.ConfigException as e:
warnings.warn(f"failed to load kube config: {e}")
# Try in-cluster config first (for pods with ServiceAccount)
config.load_incluster_config(client_configuration=configuration)
except config.ConfigException:
# Fall back to kubeconfig (for local development)
try:
config.load_kube_config(client_configuration=configuration)
except config.ConfigException as e:
warnings.warn(f"failed to load kube config: {e}")

c = self._client = client.ApiClient(configuration)

Expand Down Expand Up @@ -783,7 +819,26 @@ def _submit_dryrun(
priority_class, str
), "priority_class must be a str"

resource = app_to_resource(app, queue, service_account, priority_class)
reserved_millicpu = cfg.get("reserved_millicpu", RESERVED_MILLICPU)
assert isinstance(reserved_millicpu, int), "reserved_millicpu must be an int"

reserved_memmb = cfg.get("reserved_memmb", RESERVED_MEMMB)
assert isinstance(reserved_memmb, int), "reserved_memmb must be an int"

efa_device_count = cfg.get("efa_device_count")
assert efa_device_count is None or isinstance(
efa_device_count, int
), "efa_device_count must be an int or None"

resource = app_to_resource(
app,
queue,
service_account,
priority_class,
reserved_millicpu,
reserved_memmb,
efa_device_count,
)

if cfg.get("validate_spec"):
try:
Expand Down Expand Up @@ -889,9 +944,29 @@ def _run_opts(self) -> runopts:
help="Validate job spec using Kubernetes API dry-run before submission",
default=True,
)
opts.add(
"reserved_millicpu",
type_=int,
help="Amount of CPU in millicores to reserve for Kubernetes system overhead (default: 100)",
default=RESERVED_MILLICPU,
)
opts.add(
"reserved_memmb",
type_=int,
help="Amount of memory in MB to reserve for Kubernetes system overhead (default: 1024)",
default=RESERVED_MEMMB,
)
opts.add(
"efa_device_count",
type_=int,
help="EFA device count override: None/unset=use resource spec, "
"0=remove EFA, N>0=set EFA count to N",
default=None,
)
return opts

def describe(self, app_id: str) -> Optional[DescribeAppResponse]:
from kubernetes import client
from kubernetes.client.rest import ApiException

namespace, name = app_id.split(":")
Expand All @@ -917,8 +992,8 @@ def describe(self, app_id: str) -> Optional[DescribeAppResponse]:
TASK_STATUS_COUNT = "taskStatusCount"

if TASK_STATUS_COUNT in status:
for name, status in status[TASK_STATUS_COUNT].items():
role, _, idx = name.rpartition("-")
for task_name, status in status[TASK_STATUS_COUNT].items():
role, _, idx = task_name.rpartition("-")

state_str = next(iter(status["phase"].keys()))
state = TASK_STATE[state_str]
Expand All @@ -927,8 +1002,32 @@ def describe(self, app_id: str) -> Optional[DescribeAppResponse]:
roles[role] = Role(name=role, num_replicas=0, image="")
roles_statuses[role] = RoleStatus(role, [])
roles[role].num_replicas += 1

# Pod name follows the pattern: {job_name}-{task_name}-0
# Get the pod to retrieve its IP address
pod_name_k8s = f"{name}-{task_name}-0"
try:
core_api = client.CoreV1Api(self._api_client())
pod = core_api.read_namespaced_pod(
name=pod_name_k8s, namespace=namespace
)
pod_ip = pod.status.pod_ip

# Convert IP to dashed format (e.g., 10.244.1.5 -> 10-244-1-5)
pod_ip_dashed = pod_ip.replace(".", "-")

# Kubernetes DNS = <pod-ip-dashed>.<namespace>.pod.cluster.local
            # Note: This will only be useful if the client resolving these hostnames is running inside the cluster.
hostname = f"{pod_ip_dashed}.{namespace}.pod.cluster.local"

except ApiException:
# Fallback to old behavior if pod not found
hostname = ""

roles_statuses[role].replicas.append(
ReplicaStatus(id=int(idx), role=role, state=state, hostname="")
ReplicaStatus(
id=int(idx), role=role, state=state, hostname=hostname
)
)
else:
app_state = AppState.UNKNOWN
Expand Down
Loading
Loading