---
name: torch-xla
description: >
  Expert guide for PyTorch/XLA (torch_xla) development targeting TPU and XLA-compatible devices.
  Use this skill whenever the user asks about: torch_xla, PyTorch on TPU, XLA training, SPMD
  parallelism, distributed TPU training, XLA checkpointing, MpDeviceLoader, xm.mark_step,
  torch_xla.launch, FSDPv2 on XLA, gradient sharding, XLA compilation, lazy tensor execution,
  or migrating GPU PyTorch code to TPU. Trigger even for partial mentions like "TPU training",
  "XLA device", "xla backend", "run on TPU", or "SPMD mesh".
  This skill covers torch_xla 2.9.0 (latest stable, requires torch==2.9.0, Python 3.10–3.13).
---

# torch_xla 스킬 (v2.9.0)

PyTorch/XLA를 사용해 TPU 및 XLA 디바이스에서 모델을 훈련·추론하기 위한 종합 가이드.

## 핵심 개념: Lazy Execution

XLA 디바이스에서 PyTorch 연산은 즉시 실행되지 않는다. 연산 그래프를 누적한 뒤 `torch_xla.sync()` (또는 `xm.mark_step()`) 호출 시점에 컴파일·실행된다. 이 지연 실행(lazy execution) 모델이 GPU와의 가장 큰 차이점이다.

- **트리거**: `torch_xla.sync()`, `MpDeviceLoader` (자동 호출), `loss.item()`, `print(tensor)` 등 값을 구체화(materialize)하는 모든 연산
- **주의**: 루프 내에서 `.item()`이나 `print`를 남발하면 매번 컴파일이 발생해 성능이 급격히 저하된다. 대신 `xm.add_step_closure()`를 사용한다.

---

## 1. 설치

```bash
# TPU VM (Python 3.10–3.13 지원)
pip install torch==2.9.0 'torch_xla[tpu]==2.9.0'

# Pallas 커스텀 커널 (선택)
pip install --pre torch_xla[pallas] \
  --index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
  --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html

# 설치 확인
python -c "import torch_xla; print(torch_xla.__version__)"
python -c "import torch; print(torch.tensor(1.0, device='xla').device)"
```

> 2.7부터 C++11 ABI가 기본값. 2.9는 torch==2.9.0과 함께 사용해야 한다.

---

## 2. 기본 사용법

### 단일 디바이스

```python
import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = torch_xla.device()   # 또는 torch.device('xla')
model = MyModel().to(device)
optimizer = torch.optim.Adam(model.parameters())

for data, target in dataloader:
    # 권장: torch_xla.step() context manager (2.5+)
    with torch_xla.step():
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        loss = loss_fn(model(data), target)
        loss.backward()
        optimizer.step()
        # step() 컨텍스트를 벗어나면 자동으로 sync
```

### 멀티 디바이스 (torch_xla.launch — 권장)

```python
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl

def train_fn(index):
    device = torch_xla.device()
    model = MyModel().to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    # MpDeviceLoader: 데이터를 디바이스에 프리페치 + mark_step 자동 호출
    mp_loader = pl.MpDeviceLoader(train_loader, device)

    for data, target in mp_loader:
        optimizer.zero_grad()
        loss = loss_fn(model(data), target)
        loss.backward()
        xm.optimizer_step(optimizer)  # allreduce + step + sync

if __name__ == '__main__':
    torch_xla.launch(train_fn, args=())
```

**`torch_xla.launch` vs `xmp.spawn`**: 2.5+에서는 `torch_xla.launch`가 공식 권장 API다. `xmp.spawn`은 하위 호환성을 위해 유지되나 신규 코드에서는 `launch`를 사용한다.

---

## 3. 분산 학습

자세한 내용은 `references/distributed.md` 참조.

### 3-1. DDP (DistributedDataParallel)

```python
import torch.distributed as dist
import torch_xla.distributed.xla_backend  # xla:// init_method 등록

dist.init_process_group("xla", init_method='xla://')
model = DDP(model.to('xla'), gradient_as_bucket_view=True)
```

### 3-2. SPMD (Single Program Multiple Data) — 대규모 LLM 권장

```python
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()  # 반드시 다른 XLA 연산 이전에 호출

num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(
    device_ids=list(range(num_devices)),
    mesh_shape=(num_devices,),
    axis_names=('data',)
)

# 텐서 샤딩 어노테이션
xs.mark_sharding(model.weight, mesh, ('data', None))

# 입력 샤딩
xs.mark_sharding(input_tensor, mesh, ('data',))
```

### 3-3. FSDPv2 (SPMD 기반 FSDP)

```python
import numpy as np
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.experimental.spmd_fully_sharded_data_parallel import (
    SpmdFullyShardedDataParallel as FSDPv2
)

xr.use_spmd()
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(
    np.array(range(num_devices)).reshape(num_devices, 1),
    mesh_shape=(num_devices, 1),
    axis_names=('fsdp', 'model')
)

# gradient checkpointing은 FSDPv2 래핑 이전에 적용해야 함
from torch_xla.distributed.fsdp import checkpoint_module
model = FSDPv2(checkpoint_module(my_module), mesh)
```

### 3-4. HybridMesh (멀티슬라이스 TPU)

```python
from torch_xla.distributed.spmd import HybridMesh

# ici: intra-chip interconnect (슬라이스 내부)
# dcn: data center network (슬라이스 간)
mesh = HybridMesh(
    ici_mesh_shape=(8,),    # 슬라이스당 8 디바이스
    dcn_mesh_shape=(2,),    # 슬라이스 2개
    axis_names=('fsdp',)
)
```

---

## 4. 체크포인트 및 가중치 저장

자세한 내용은 `references/checkpointing.md` 참조.

### 4-1. 단일 디바이스 / DDP

```python
# 저장 — master_only=True 권장
xm.save(model.state_dict(), "model.pt", master_only=True)

# 최대 이식성이 필요한 경우 (비-XLA 환경에서도 로드 가능)
if xm.is_master_ordinal():
    cpu_state = {k: v.cpu() for k, v in model.state_dict().items()}
    torch.save(cpu_state, "model_cpu.pt")

# 로드
model.load_state_dict(torch.load("model_cpu.pt"))
model.to(device)
```

**주의**: `xm.save()`는 모든 프로세스에서 호출해야 한다. 마스터 프로세스만 호출하면 동기화 대기로 훈련이 멈춘다.

### 4-2. SPMD 분산 체크포인트

```python
import torch.distributed.checkpoint as dist_cp
import torch_xla.experimental.distributed_checkpoint as xc

# SPMD용 process group 초기화 (xla 백엔드 미지원, gloo 사용)
import torch.distributed as dist
import torch_xla.distributed.xla_backend
xr.use_spmd()
dist.init_process_group('gloo', init_method='xla://')

# 저장
state_dict = {"model": model.state_dict(), "optim": optim.state_dict()}
dist_cp.save(
    state_dict=state_dict,
    storage_writer=dist_cp.FileSystemWriter(CHECKPOINT_DIR),
    planner=xc.SPMDSavePlanner(),
)

# 로드 (모델이 이미 XLA 디바이스에 있고 샤딩이 적용된 상태여야 함)
state_dict = {"model": model.state_dict()}
dist_cp.load(
    state_dict=state_dict,
    storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
    planner=xc.SPMDLoadPlanner(),
)
model.load_state_dict(state_dict["model"])
```

### 4-3. CheckpointManager (고수준 API)

```python
from torch_xla.experimental.distributed_checkpoint import (
    CheckpointManager, prime_optimizer
)

# GCS 또는 로컬 경로에 10 스텝마다 체크포인트
chkpt_mgr = CheckpointManager('gs://my-bucket/experiment', save_interval=10)

# 복원
tracked_steps = chkpt_mgr.all_steps()
if tracked_steps:
    best_step = max(tracked_steps)
    prime_optimizer(optim)  # optimizer state 복원 전 반드시 호출
    state_dict = {'model': model.state_dict(), 'optim': optim.state_dict()}
    chkpt_mgr.restore(best_step, state_dict)
    model.load_state_dict(state_dict['model'])
    optim.load_state_dict(state_dict['optim'])

# 훈련 루프 내
for step, (data, target) in enumerate(loader):
    # ... 훈련 코드 ...
    chkpt_mgr.save(step, {'model': model.state_dict()})
    # 또는 비동기 저장 (훈련 블로킹 없음)
    chkpt_mgr.save_async(step, {'model': model.state_dict()})
```

---

## 5. 훈련 파이프라인

### 표준 훈련 루프 패턴

```python
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
from torch_xla.amp import syncfree

def train_fn(index):
    device = torch_xla.device()

    model = MyModel().to(device)
    # AMP 사용 시 syncfree optimizer 사용 (device-host sync 제거)
    optimizer = syncfree.AdamW(model.parameters(), lr=1e-4)

    mp_loader = pl.MpDeviceLoader(
        train_loader,
        device,
        batches_per_execution=1,   # 증가 시 그래프 커짐, OOM 주의
        loader_prefetch_size=8,
        device_prefetch_size=4,
    )

    for epoch in range(num_epochs):
        model.train()
        for step, (data, target) in enumerate(mp_loader):
            optimizer.zero_grad()

            # AMP (TPU는 bfloat16 네이티브 지원, loss scaling 불필요)
            with torch.autocast('xla', dtype=torch.bfloat16):
                output = model(data)
                loss = loss_fn(output, target)

            loss.backward()
            xm.optimizer_step(optimizer)

            # loss 출력 시 step closure로 감싸기 (불필요한 recompile 방지)
            if step % 100 == 0:
                xm.add_step_closure(
                    lambda l: print(f"step={step}, loss={l:.4f}"),
                    args=(loss.detach(),)
                )

torch_xla.launch(train_fn)
```

### torch_xla.compile 사용 (2.5+)

```python
# Dynamo 기반 컴파일 — 초기 트레이싱 후 캐시 재사용
compiled_model = torch.compile(model, backend='openxla')

# full_graph 모드: 여러 컴파일 그래프 발생 시 에러 발생 (디버깅용)
compiled_model = torch.compile(model, backend='openxla',
                                options={'full_graph': True})
```

---

## 6. 데이터로더 파이프라인

```python
import torch
import torch_xla.distributed.parallel_loader as pl

# DistributedSampler로 각 프로세스에 다른 데이터 할당
train_sampler = torch.utils.data.distributed.DistributedSampler(
    dataset,
    num_replicas=xm.xrt_world_size(),
    rank=xm.get_ordinal(),
    shuffle=True,
)

train_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    sampler=train_sampler,
    num_workers=4,         # CPU 워커: 디바이스 수 이상, CPU 수 미만으로 설정
    drop_last=True,        # XLA는 정적 shape 선호 — drop_last=True 강력 권장
    prefetch_factor=2,
    persistent_workers=True,
)

# MpDeviceLoader: 백그라운드에서 데이터 업로드 + 자동 sync
mp_loader = pl.MpDeviceLoader(train_loader, device)

# SPMD 모드에서 샤딩 힌트 전달
mp_loader_spmd = pl.MpDeviceLoader(
    train_loader,
    device,
    input_sharding=xs.ShardingSpec(mesh, ('data', None, None, None))
)
```

**핵심 원칙**: XLA에서 shape 변화는 재컴파일을 유발한다. 가능한 모든 배치가 동일한 shape을 갖도록 `drop_last=True`와 패딩을 적극 활용한다.

---

## 7. Best Practices

**Lazy Execution 관련**
- `loss.item()`, `print(tensor)` 등 값을 직접 읽는 코드는 컴파일 트리거이므로 루프 내 최소화
- 텐서 값이 필요하면 `xm.add_step_closure()` 사용
- 정적 shape 유지: 동적 shape은 매번 재컴파일을 유발

**성능**
- TPU는 bfloat16 네이티브 지원 — `torch.autocast('xla', dtype=torch.bfloat16)` 사용, loss scaling 불필요
- AMP 사용 시 `torch_xla.amp.syncfree` 옵티마이저 사용 (SGD, Adam, AdamW)
- `MpDeviceLoader`는 항상 사용: 데이터 로딩과 디바이스 실행을 오버랩
- C++11 ABI 빌드(2.7+)는 lazy tensor tracing 성능이 더 좋음

**분산 학습**
- 모든 워커에서 동일한 랜덤 시드 사용 (정확도 일관성)
- `xm.save()`는 반드시 모든 프로세스에서 호출 (마스터만 호출 시 데드락)
- SPMD 모드에서 distributed checkpoint는 gloo process group 사용 (xla 백엔드 불가)
- FSDPv2에서 gradient checkpointing은 FSDPv2 래핑 이전에 적용

**디버깅**
```python
# 메트릭 출력 (느린 연산, fallback 확인)
import torch_xla.debug.metrics as met
print(met.metrics_report())

# 환경 변수
# TORCH_SHOW_CPP_STACKTRACES=1  : C++ 스택트레이스 출력 (구 XLA_SHOW_CPP_ERROR_CONTEXT)
# XLA_FLAGS=--xla_dump_to=/tmp/xla_dumps  : HLO 덤프
# PT_XLA_DEBUG_LEVEL=1          : 실행 원인 vs 컴파일 원인 구분
```

**재컴파일 방지**
- 루프 내 Python 분기(`if step == 0`) 최소화
- `batches_per_execution` 값을 높이면 그래프가 커지고 OOM 위험 증가
- shape이 변하는 연산(동적 padding 등)은 고정 shape으로 교체

---

## 8. 모듈 구조 참조

| 모듈 | 용도 |
|---|---|
| `torch_xla` | 최상위 API (`device()`, `sync()`, `step()`, `launch()`, `compile()`) |
| `torch_xla.core.xla_model` (xm) | 디바이스 관리, `mark_step`, `optimizer_step`, `save`, `get_ordinal` |
| `torch_xla.runtime` (xr) | `use_spmd()`, `global_runtime_device_count()`, `world_size()` |
| `torch_xla.distributed.spmd` (xs) | SPMD Mesh, `mark_sharding`, `HybridMesh` |
| `torch_xla.distributed.parallel_loader` (pl) | `MpDeviceLoader`, `ParallelLoader` |
| `torch_xla.amp.syncfree` | sync-free SGD, Adam, AdamW |
| `torch_xla.experimental.distributed_checkpoint` (xc) | `CheckpointManager`, `SPMDSavePlanner`, `SPMDLoadPlanner` |
| `torch_xla.debug.metrics` | 성능 메트릭, fallback 확인 |

---


## 레퍼런스

복잡한 사용 사례는 아래 파일 참조:
- `references/distributed.md` — DDP / SPMD / FSDPv2 심화
- `references/checkpointing.md` — CheckpointManager 심화, GCS 연동

### 서브 에이전트
요청의 복잡도에 따라서 서브 에이전트를 호출.

서브 에이전트는 아래 파일 참조:
- `agents/architect.md` — 설계 담당
- `agents/implementer.md` — 구현 담당
- `agents/reviewer_standard.md` — 리뷰 담당
