Running LLMs on Strix Halo
This is a guide to getting started with running LLMs on AMD Strix Halo.
The full code and all scripts are in the GitHub repo.
Some excellent resources:
– this AMD Strix Halo guide
– and the accompanying repo
– this repo contains some useful standalone built components
Tips on OS installation
I recently got my AMD Ryzen AI Max+ 395 machine from Framework. Framework has a bunch of excellent guides on setting up your machine (including which OS to use), which you can follow in the Framework docs.
I chose Fedora as the OS because it seemed to be the most solidly supported Linux distro on Framework computers, and I mainly want to use the machine to run LLMs.
Fix Preliminary Memory Stuff
By default, not all of the iGPU memory is available to PyTorch.
Strix Halo’s iGPU uses UMA: its “VRAM” is system RAM mapped through the GTT. The amdgpu/TTM memory manager enforces conservative page-pool limits that can fragment or block large ROCm/HIP allocations, so raising ttm’s pages_limit and page_pool_size (and rebuilding the initramfs so the new options apply at boot) lets PyTorch actually allocate and use most of the shared “VRAM.”
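# Confirm the ttm module is loaded
ls /sys/module/ttm 2>/dev/null && echo "Confirming we are using ttm"
# List the parameters ttm accepts (e.g. pages_limit, page_pool_size)
modinfo -p ttm 2>/dev/null || true
# Raise both limits; the values are counts of 4 KiB pages (31457280 pages = 120 GiB)
sudo tee /etc/modprobe.d/ttm.conf >/dev/null <<'EOF'
options ttm pages_limit=31457280 page_pool_size=31457280
EOF
# Rebuild the initramfs so the options are applied at boot, then reboot
sudo dracut -f --regenerate-all
sudo reboot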
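Because both ttm parameters are page counts rather than byte sizes, it is easy to get the number wrong when picking a different pool size. A minimal sketch for converting a target size in GiB into the page count and the matching `ttm.conf` line, assuming 4 KiB pages (the x86-64 default); `ttm_pages` and `ttm_conf_line` are hypothetical helper names, not part of any tool:

```python
# Hypothetical helpers for sizing the TTM pool. Both ttm parameters count
# pages, not bytes; this assumes 4 KiB pages (the x86-64 default, which
# `getconf PAGESIZE` should confirm).
PAGE_SIZE = 4096


def ttm_pages(gib: float, page_size: int = PAGE_SIZE) -> int:
    """Number of whole pages in `gib` GiB."""
    return int(gib * 2**30) // page_size


def ttm_conf_line(gib: float) -> str:
    """Options line for /etc/modprobe.d/ttm.conf setting both limits to `gib` GiB."""
    pages = ttm_pages(gib)
    return f"options ttm pages_limit={pages} page_pool_size={pages}"


if __name__ == "__main__":
    print(ttm_pages(120))                      # 31457280, the value used above
    print(ttm_pages(120) * PAGE_SIZE / 2**30)  # 120.0 (GiB)
    print(ttm_conf_line(120))
```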
Running Torch and Transformers
If you want to play around with models and torch, the easiest way I found is to get the relevant torch and ROCm builds from TheRock, which is meant to be a build platform that unifies the historically fiddly experience of patching and building all the components needed for AMD hardware.
I found something like this to work pretty nicely out of the box:
mkdir test && cd test
uv venv
uv pip install \
--index-url https://rocm.nightlies.amd.com/v2/gfx1151/ \
"rocm[libraries,devel]"
uv pip install \
--index-url https://rocm.nightlies.amd.com/v2/gfx1151/ \
--pre torch
uv pip install transformers accelerate
uv run python <<'PY'
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
model_name = "Qwen/Qwen3-1.7B"
print(f"Loading {model_name} model...")
tokenizer = AutoTokenizer.from_pretrained(model_name)
# ROCm builds of PyTorch expose the iGPU through the usual "cuda" device API
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    dtype=torch.bfloat16,
).to("cuda:0")
print("\nRunning inference...")
prompt = "What is the capital of Bulgaria?"
# Tokenize and move the input tensors to the model's device
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
# Sample up to 100 new tokens (temperature + nucleus sampling)
outputs = model.generate(
    **inputs,
    max_new_tokens=100,
    do_sample=True,
    temperature=0.7,
    top_p=0.95,
)
# generate() returns prompt + continuation; decode only the new tokens
new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
response = tokenizer.decode(new_tokens, skip_special_tokens=True)
print(f"\nPrompt: {prompt}")
print(f"Response: {response}")
PY
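Before moving on, it can help to confirm that the TheRock PyTorch build actually sees the iGPU, and to check how much memory the HIP runtime reports (for example after the TTM change above). A minimal sketch using only standard PyTorch APIs (`torch.version.hip`, `torch.cuda.is_available`, `torch.cuda.mem_get_info`); `describe_device` is a hypothetical helper name, and on a UMA part the reported total may depend on your kernel and driver configuration, so treat it as a sanity check rather than an exact figure:

```python
# Sanity check for a ROCm build of PyTorch: is the GPU visible, and how much
# memory does the runtime report? Returns None if torch or a ROCm GPU is missing.
from typing import Optional


def describe_device(index: int = 0) -> Optional[dict]:
    try:
        import torch
    except ImportError:
        return None
    # On ROCm builds torch.version.hip is a version string; elsewhere it is None
    if not getattr(torch.version, "hip", None) or not torch.cuda.is_available():
        return None
    free, total = torch.cuda.mem_get_info(index)  # bytes
    return {
        "name": torch.cuda.get_device_name(index),
        "hip": torch.version.hip,
        "free_gib": free / 2**30,
        "total_gib": total / 2**30,
    }


if __name__ == "__main__":
    info = describe_device()
    if info is None:
        print("No ROCm GPU visible to PyTorch")
    else:
        print(f"{info['name']} (HIP {info['hip']}): "
              f"{info['free_gib']:.1f} of {info['total_gib']:.1f} GiB free")
```

Save it as, say, `check_gpu.py` and run it with `uv run python check_gpu.py` from the `test` directory.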
vLLM
Below is a slightly hacky way to build vLLM with ROCm for Strix Halo.
Note that several useful features, such as HIP graphs, do not work with this build, so I will probably come back and see if it can be done better.
# Customise via env
INDEX_URL="${INDEX_URL:-https://rocm.nightlies.amd.com/v2/gfx1151/}"
VLLM_REPO="${VLLM_REPO:-https://github.com/vllm-project/vllm.git}"
VLLM_BRANCH="${VLLM_BRANCH:-main}"
VLLM_DIR="${VLLM_DIR:-vllm}"
GPU_ARCH="${GPU_ARCH:-gfx1151}"
BUILD_TYPE="${BUILD_TYPE:-Release}"
echo "=== ROCm TheRock + vLLM Build ==="
mkdir -p test && cd test
uv venv
source ./.venv/bin/activate
# Core deps + ROCm toolchain
uv pip install --index-url "$INDEX_URL" "rocm[libraries,devel]"
uv pip install ninja cmake wheel pybind11
uv pip install --index-url "$INDEX_URL" torch torchvision
# Fix numpy version compatibility (numba requires <2.2)
uv pip install "numpy>=1.26,<2.2"
# Misc dependencies
uv pip install --upgrade numba scipy "huggingface-hub[cli]"
# Discover ROCm paths from the venv
ROCM_ROOT="$(python - <<'PY'
import subprocess,sys; print(subprocess.check_output(
[sys.executable,"-P","-m","rocm_sdk","path","--root"], text=True).strip())
PY
)"
ROCM_BIN="$(python - <<'PY'
import subprocess,sys; print(subprocess.check_output(
[sys.executable,"-P","-m","rocm_sdk","path","--bin"], text=True).strip())
PY
)"
LLVM_BIN="${ROCM_ROOT}/lib/llvm/bin"
ROCM_CMAKE="${ROCM_ROOT}/lib/cmake"
ROCM_BC="${ROCM_ROOT}/lib/llvm/amdgcn/bitcode"
# Export so CMake/Clang find ROCm inside the venv (no /opt/rocm)
export HIP_PLATFORM=amd
export ROCM_PATH="$ROCM_ROOT"
export HIP_PATH="$ROCM_ROOT"
export HIP_CLANG_PATH="$LLVM_BIN"
export HIP_DEVICE_LIB_PATH="$ROCM_BC"
export PATH="$ROCM_BIN:$LLVM_BIN:$PATH"
export LD_LIBRARY_PATH="${ROCM_ROOT}/lib:${ROCM_ROOT}/lib64:${LD_LIBRARY_PATH:-}"
export CMAKE_PREFIX_PATH="${ROCM_CMAKE}:${CMAKE_PREFIX_PATH:-}"
export AMDGPU_TARGETS="${GPU_ARCH}"
export GPU_TARGETS="${GPU_ARCH}"
export PYTORCH_ROCM_ARCH="${GPU_ARCH}"
export TRITON_HIP_LLD_PATH="${LLVM_BIN}/ld.lld"
export HIP_VISIBLE_DEVICES="0"
export AMD_SERIALIZE_KERNEL="3"
# Clone / update vLLM
if [[ ! -d "$VLLM_DIR" ]]; then
git clone --branch "$VLLM_BRANCH" "$VLLM_REPO" "$VLLM_DIR"
else
(cd "$VLLM_DIR" && git fetch origin "$VLLM_BRANCH" && git reset --hard "origin/$VLLM_BRANCH")
fi
cd "$VLLM_DIR"
# Run use_existing_torch.py to configure for our PyTorch build
python use_existing_torch.py
# Apply gfx1151 patches
echo "=== Applying gfx1151 Patches ==="
# Add gfx1151 to CMakeLists.txt
echo "Adding gfx1151 to CMakeLists.txt..."
if ! grep -q "gfx1151" CMakeLists.txt; then
sed -i 's/set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1101;gfx1200;gfx1201")/set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1101;gfx1151;gfx1200;gfx1201")/' CMakeLists.txt
echo "✅ Added gfx1151 to CMakeLists.txt"
else
echo "✅ gfx1151 already present in CMakeLists.txt"
fi
# Remove torch dependency from pyproject.toml
echo "Removing torch dependency from pyproject.toml..."
sed -i '/torch == 2.8.0,/d' pyproject.toml || true
echo "✅ Removed torch dependency from pyproject.toml"
# Remove amdsmi from rocm-build.txt (causes segfaults on gfx1151)
echo "Removing amdsmi from requirements..."
sed -i '/amdsmi==/d' requirements/rocm-build.txt || true
uv pip install -r requirements/rocm-build.txt
# Patch ROCm platform detection to use torch instead of amdsmi
echo "Patching ROCm platform detection..."
if ! grep -q "torch.version.hip" vllm/platforms/__init__.py; then
# Replace amdsmi-based detection with torch-based detection
sed -i '/def rocm_platform_plugin/,/return "vllm.platforms.rocm.RocmPlatform"/c\
def rocm_platform_plugin() -> str | None:\
    import torch\
    is_rocm = hasattr(torch, "version") and hasattr(torch.version, "hip") and torch.version.hip\
    return "vllm.platforms.rocm.RocmPlatform" if is_rocm else None' vllm/platforms/__init__.py
echo "✅ Patched ROCm platform detection"
else
echo "✅ ROCm platform detection already patched"
fi
echo "✅ All gfx1151 patches applied"
# Build vLLM with ROCm target device
echo ""
echo "=== Building vLLM ==="
# Get current torch version for constraints
TORCH_VERSION=$(python -c "import torch; print(torch.__version__)" 2>/dev/null || echo "unknown")
echo "Current PyTorch version: $TORCH_VERSION"
# Create constraints file to preserve ROCm torch
echo "torch==$TORCH_VERSION" > /tmp/vllm-constraints.txt
echo "Installing vLLM with constraints to preserve ROCm torch..."
VLLM_TARGET_DEVICE=rocm uv pip install -e . --no-build-isolation --constraint /tmp/vllm-constraints.txt
rm -f /tmp/vllm-constraints.txt
# Final numpy version check
uv pip install "numpy>=1.26,<2.2"
cd ..
echo
echo "✅ vLLM built ok"
echo "=== Creating utility to activate environment ==="
cat > ./.venv/bin/activate-rocm <<'EOF'
#!/usr/bin/env bash
ROCM_ROOT="$(python - <<'PY'
import subprocess,sys
print(subprocess.check_output([sys.executable,"-P","-m","rocm_sdk","path","--root"], text=True).strip())
PY
)"
ROCM_BIN="$(python - <<'PY'
import subprocess,sys
print(subprocess.check_output([sys.executable,"-P","-m","rocm_sdk","path","--bin"], text=True).strip())
PY
)"
LLVM_BIN="${ROCM_ROOT}/lib/llvm/bin"
export HIP_PLATFORM=amd
export ROCM_PATH="$ROCM_ROOT"
export HIP_PATH="$ROCM_ROOT"
export HIP_CLANG_PATH="$LLVM_BIN"
export HIP_DEVICE_LIB_PATH="$ROCM_ROOT/lib/llvm/amdgcn/bitcode"
export PATH="$ROCM_BIN:$LLVM_BIN:$PATH"
export LD_LIBRARY_PATH="$ROCM_ROOT/lib:$ROCM_ROOT/lib64:${LD_LIBRARY_PATH:-}"
export AMDGPU_TARGETS="${AMDGPU_TARGETS:-gfx1151}"
export GPU_TARGETS="${GPU_TARGETS:-gfx1151}"
export HIP_VISIBLE_DEVICES="${HIP_VISIBLE_DEVICES:-0}"
export AMD_SERIALIZE_KERNEL="3"
export TRITON_HIP_LLD_PATH="$LLVM_BIN/ld.lld"
echo "✅ ROCm env ready at: $ROCM_ROOT"
EOF
chmod +x ./.venv/bin/activate-rocm
echo "✅ all set -- run 'source test/.venv/bin/activate && source test/.venv/bin/activate-rocm' to use your build"
echo ""
echo "To test vLLM:"
echo " python -c \"import vllm; print('vLLM version:', vllm.__version__)\""
echo ""
echo "To run vLLM server (use --enforce-eager to disable torch.compile): "
echo "vllm serve Qwen/Qwen3-1.7B --gpu-memory-utilization 0.75 --enforce-eager"
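The `sed` patch for `CMakeLists.txt` in the script above only fires when the `HIP_SUPPORTED_ARCHS` line matches the hard-coded list exactly. If upstream vLLM edits that list, `sed -i` silently changes nothing and the script still prints "Added". A hedged Python sketch of a more tolerant patch, assuming the line keeps the `set(HIP_SUPPORTED_ARCHS "...")` shape used in the `sed` command; `add_hip_arch` and the `add_hip_arch.py` file name are hypothetical:

```python
# Hypothetical alternative to the exact-match sed on vLLM's CMakeLists.txt:
# append an arch to whatever HIP_SUPPORTED_ARCHS list is present, and fail
# loudly (instead of silently doing nothing) if the line is not found.
import re
import sys
from pathlib import Path

# Assumes the line keeps the shape used by the sed command:
#   set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;...")
ARCHS_RE = re.compile(r'set\(HIP_SUPPORTED_ARCHS "([^"]*)"\)')


def add_hip_arch(cmake_text: str, arch: str) -> str:
    match = ARCHS_RE.search(cmake_text)
    if match is None:
        raise ValueError("set(HIP_SUPPORTED_ARCHS ...) not found")
    archs = match.group(1).split(";")
    if arch in archs:
        return cmake_text  # already patched, leave the file untouched
    archs.append(arch)
    joined = ";".join(archs)
    new_line = f'set(HIP_SUPPORTED_ARCHS "{joined}")'
    return cmake_text[: match.start()] + new_line + cmake_text[match.end():]


def main(paths: list) -> None:
    for name in paths:
        path = Path(name)
        path.write_text(add_hip_arch(path.read_text(), "gfx1151"))
        print(f"gfx1151 present in {path}")


if __name__ == "__main__":
    # Usage (from the vllm checkout): python add_hip_arch.py CMakeLists.txt
    main(sys.argv[1:])
```

Appending changes the order of the list relative to the `sed` version; I am assuming the order does not matter to vLLM's CMake logic.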
Running vLLM server:
# Activate environment
source test/.venv/bin/activate && source test/.venv/bin/activate-rocm
# Start server (--enforce-eager is required to disable torch.compile which has issues)
vllm serve Qwen/Qwen3-1.7B --gpu-memory-utilization 0.75 --enforce-eager
# Test with curl
curl http://localhost:8000/v1/completions \
-H "Content-Type: application/json" \
-d '{
"model": "Qwen/Qwen3-1.7B",
"prompt": "What is the capital of Bulgaria?",
"max_tokens": 50
}'
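The same request can be sent from Python with only the standard library. A minimal sketch, assuming the server is running on vLLM's default `localhost:8000` as above and returns the usual OpenAI completions shape (`choices[0].text`); `build_request` and `complete` are hypothetical helper names:

```python
# Minimal standard-library client for the OpenAI-compatible completions
# endpoint started by `vllm serve` above (default port 8000 assumed).
import json
import urllib.request

BASE_URL = "http://localhost:8000"


def build_request(prompt: str, model: str = "Qwen/Qwen3-1.7B",
                  max_tokens: int = 50) -> urllib.request.Request:
    """Same JSON body as the curl example."""
    body = json.dumps(
        {"model": model, "prompt": prompt, "max_tokens": max_tokens}
    ).encode("utf-8")
    return urllib.request.Request(
        f"{BASE_URL}/v1/completions",
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def complete(prompt: str, timeout: float = 120.0) -> str:
    with urllib.request.urlopen(build_request(prompt), timeout=timeout) as resp:
        payload = json.load(resp)
    # Completion responses put the generated text in choices[i]["text"]
    return payload["choices"][0]["text"]


if __name__ == "__main__":
    try:
        print(complete("What is the capital of Bulgaria?"))
    except OSError as exc:  # URLError and socket timeouts are OSErrors
        print(f"Could not reach {BASE_URL}: {exc}")
```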
llama.cpp
There are two backends you can use for llama.cpp: Vulkan and ROCm.
Building llama.cpp with Vulkan:
# Customise via env
BUILD_TYPE="${BUILD_TYPE:-Release}"
# export AMD_VULKAN_ICD=RADV
echo "=== Vulkan + llama.cpp Build ==="
# Install dependencies
echo "Installing dependencies..."
sudo dnf -y install git cmake ninja-build pkgconf-pkg-config libcurl-devel
sudo dnf -y install vulkan-tools vulkan-loader mesa-vulkan-drivers \
vulkan-loader-devel vulkan-headers
sudo dnf -y install glslc glslang spirv-tools python3 python3-pip
# Clone llama.cpp
git clone https://github.com/ggml-org/llama.cpp/ llama.cpp-vulkan
cd llama.cpp-vulkan
rm -rf build
# Configure with Vulkan support
cmake -S . -B build -G Ninja \
-DGGML_VULKAN=ON \
-DLLAMA_BUILD_SERVER=ON \
-DLLAMA_BUILD_EXAMPLES=ON \
-DCMAKE_BUILD_TYPE="$BUILD_TYPE"
# Build
ninja -C build
Building with ROCm:
# Customise via env
INDEX_URL="${INDEX_URL:-https://rocm.nightlies.amd.com/v2/gfx1151/}"
LLAMA_REPO="${LLAMA_REPO:-https://github.com/ROCm/llama.cpp.git}"
LLAMA_BRANCH="${LLAMA_BRANCH:-amd-integration}"
LLAMA_DIR="${LLAMA_DIR:-llama.cpp-rocm}"
GPU_ARCH="${GPU_ARCH:-gfx1151}"
BUILD_TYPE="${BUILD_TYPE:-Release}"
echo "=== ROCm + llama.cpp ==="
mkdir -p test && cd test
uv venv
source ./.venv/bin/activate
# ROCm toolchain (venv-local) + build helpers
uv pip install --index-url "$INDEX_URL" "rocm[libraries,devel]"
uv pip install ninja cmake
# Discover ROCm paths from the venv
ROCM_ROOT="$(python - <<'PY'
import subprocess,sys; print(subprocess.check_output(
[sys.executable,"-P","-m","rocm_sdk","path","--root"], text=True).strip())
PY
)"
ROCM_BIN="$(python - <<'PY'
import subprocess,sys; print(subprocess.check_output(
[sys.executable,"-P","-m","rocm_sdk","path","--bin"], text=True).strip())
PY
)"
LLVM_BIN="${ROCM_ROOT}/lib/llvm/bin"
ROCM_CMAKE="${ROCM_ROOT}/lib/cmake"
ROCM_BC="${ROCM_ROOT}/lib/llvm/amdgcn/bitcode"
# Export so CMake/Clang find ROCm inside the venv (no /opt/rocm)
export HIP_PLATFORM=amd
export ROCM_PATH="$ROCM_ROOT"
export HIP_PATH="$ROCM_ROOT"
export HIP_CLANG_PATH="$LLVM_BIN"
export HIP_DEVICE_LIB_PATH="$ROCM_BC"
export PATH="$ROCM_BIN:$LLVM_BIN:$PATH"
export LD_LIBRARY_PATH="${ROCM_ROOT}/lib:${ROCM_ROOT}/lib64:${LD_LIBRARY_PATH:-}"
export CMAKE_PREFIX_PATH="${ROCM_CMAKE}:${CMAKE_PREFIX_PATH:-}"
export AMDGPU_TARGETS="${GPU_ARCH}"
export GPU_TARGETS="${GPU_ARCH}"
# Clone / update llama.cpp
if [[ ! -d "$LLAMA_DIR" ]]; then
git clone --depth=1 --branch "$LLAMA_BRANCH" "$LLAMA_REPO" "$LLAMA_DIR"
else
(cd "$LLAMA_DIR" && git fetch origin "$LLAMA_BRANCH" && git reset --hard "origin/$LLAMA_BRANCH")
fi
cd "$LLAMA_DIR"
rm -rf build
cmake -S . -B build -G Ninja \
-DGGML_HIP=ON \
-DCMAKE_BUILD_TYPE="$BUILD_TYPE" \
-DCMAKE_C_COMPILER="${LLVM_BIN}/clang" \
-DCMAKE_CXX_COMPILER="${LLVM_BIN}/clang++" \
-DCMAKE_HIP_ARCHITECTURES="${GPU_ARCH}" \
-DCMAKE_HIP_FLAGS="--rocm-path=${ROCM_ROOT} --rocm-device-lib-path=${ROCM_BC}"
cmake --build build -- -j"$(nproc)"
You can then use the following utility to make the environment easy to source:
cat > test/.venv/bin/activate-rocm <<'EOF'
#!/usr/bin/env bash
ROCM_ROOT="$(python - <<'PY'
import subprocess,sys
print(subprocess.check_output([sys.executable,"-P","-m","rocm_sdk","path","--root"], text=True).strip())
PY
)"
export ROCM_PATH="$ROCM_ROOT"
export HIP_PATH="$ROCM_ROOT"
export PATH="$ROCM_ROOT/bin:$ROCM_ROOT/lib/llvm/bin:$PATH"
export LD_LIBRARY_PATH="$ROCM_ROOT/lib:$ROCM_ROOT/lib64:${LD_LIBRARY_PATH:-}"
export HIP_DEVICE_LIB_PATH="$ROCM_ROOT/lib/llvm/amdgcn/bitcode"
export AMDGPU_TARGETS="${AMDGPU_TARGETS:-gfx1151}"
export GPU_TARGETS="${GPU_TARGETS:-gfx1151}"
export HIP_VISIBLE_DEVICES="${HIP_VISIBLE_DEVICES:-0}"
echo "✓ ROCm env ready at: $ROCM_ROOT"
EOF
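If you would rather launch the llama.cpp binaries from Python (for example to script several benchmark runs), the same environment can be built without sourcing the script. A minimal sketch, assuming `rocm_sdk` is importable from the active venv exactly as in the build script; `rocm_root`, `rocm_env`, and the `run_rocm.py` name are hypothetical. One deliberate difference from the shell version: empty entries are dropped from `PATH` and `LD_LIBRARY_PATH`, so an unset variable does not leave a trailing `:`, which both the shell and the dynamic loader treat as the current directory.

```python
# Hypothetical Python counterpart of the activate-rocm script above: build the
# environment for the venv-local ROCm and run a llama.cpp binary with it.
import os
import subprocess
import sys
from typing import Mapping, Optional


def rocm_root() -> str:
    """Ask TheRock's rocm_sdk package for the ROCm root, as the script does."""
    return subprocess.check_output(
        [sys.executable, "-P", "-m", "rocm_sdk", "path", "--root"], text=True
    ).strip()


def _join(*parts: str) -> str:
    # Drop empty entries: a trailing ":" would add the current directory
    return os.pathsep.join(p for p in parts if p)


def rocm_env(root: str, base: Optional[Mapping[str, str]] = None,
             arch: str = "gfx1151") -> dict:
    env = dict(os.environ if base is None else base)
    env["ROCM_PATH"] = root
    env["HIP_PATH"] = root
    env["PATH"] = _join(f"{root}/bin", f"{root}/lib/llvm/bin", env.get("PATH", ""))
    env["LD_LIBRARY_PATH"] = _join(
        f"{root}/lib", f"{root}/lib64", env.get("LD_LIBRARY_PATH", ""))
    env["HIP_DEVICE_LIB_PATH"] = f"{root}/lib/llvm/amdgcn/bitcode"
    # Like ${VAR:-default}: keep non-empty values the caller already set
    for key, default in (("AMDGPU_TARGETS", arch), ("GPU_TARGETS", arch),
                         ("HIP_VISIBLE_DEVICES", "0")):
        env[key] = env.get(key) or default
    return env


if __name__ == "__main__":
    # Usage: python run_rocm.py ./build/bin/llama-bench -m model.gguf -ngl 999
    if len(sys.argv) > 1:
        sys.exit(subprocess.call(sys.argv[1:], env=rocm_env(rocm_root())))
```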
Running the following command from each llama.cpp build directory gives the performance below:
./build/bin/llama-bench -m \
$PATH_TO_GPT_OSS_20_B \
-ngl 999 \
-t 8 \
-pg 256,128 \
-b 512 \
-ub 256 \
-fa 1
Vulkan:
| model | size | params | backend | ngl | threads | n_batch | n_ubatch | fa | test | t/s |
|---|---|---|---|---|---|---|---|---|---|---|
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | Vulkan | 999 | 8 | 512 | 256 | 1 | pp512 | 1040.68 ± 19.03 |
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | Vulkan | 999 | 8 | 512 | 256 | 1 | pp256+tg128 | 192.56 ± 0.58 |
ROCm:
| model | size | params | backend | ngl | threads | n_batch | n_ubatch | fa | test | t/s |
|---|---|---|---|---|---|---|---|---|---|---|
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | ROCm | 999 | 8 | 512 | 256 | 1 | pp512 | 1178.27 ± 7.69 |
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | ROCm | 999 | 8 | 512 | 256 | 1 | tg128 | 66.95 ± 0.01 |
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | ROCm | 999 | 8 | 512 | 256 | 1 | pp256+tg128 | 178.87 ± 0.20 |
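Side by side, ROCm is ahead on prompt processing (pp512) while Vulkan comes out ahead on the combined pp256+tg128 test; the Vulkan table has no standalone tg128 row, so token generation alone cannot be compared. A small sketch computing the ratios from the mean values in the tables above (standard deviations ignored):

```python
# Relative throughput of the ROCm build vs the Vulkan build, using the
# llama-bench means from the tables above (tokens/s).
vulkan = {"pp512": 1040.68, "pp256+tg128": 192.56}
rocm = {"pp512": 1178.27, "tg128": 66.95, "pp256+tg128": 178.87}


def speedups(a: dict, b: dict) -> dict:
    """Ratio b / a for every test present in both result sets."""
    return {test: b[test] / a[test] for test in a if test in b}


if __name__ == "__main__":
    for test, ratio in speedups(vulkan, rocm).items():
        print(f"{test}: ROCm/Vulkan = {ratio:.3f}")
```

This prints 1.132 for pp512 (ROCm about 13% faster) and 0.929 for pp256+tg128 (ROCm about 7% slower).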