#!/bin/sh
#
# Copyright (C) 2026 Nikos Mavrogiannopoulos
#
# This file is part of ocserv.
#
# This file is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This file is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this file; if not, write to the Free Software Foundation,
# Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.

# Input is from environment:
#
# OCSERV_RESTRICT_TO_ROUTES: If set to '1' the user should be restricted
#                    to accessing the OCSERV_ROUTES and prevented from accessing
#                    OCSERV_NO_ROUTES.
#
# OCSERV_ROUTES:     A space-separated list of IPv4 and IPv6 routes to
#                    which the user has access. If empty or not set, the
#                    user has the default route.
#
# OCSERV_ROUTES4:    A version of OCSERV_ROUTES with IPv4 addresses only.
# OCSERV_ROUTES6:    A version of OCSERV_ROUTES with IPv6 addresses only.
#
# OCSERV_NO_ROUTES:  A space-separated list of IPv4 and IPv6 routes to
#                    which the user has NO access.
#
# OCSERV_NO_ROUTES4: A version of OCSERV_NO_ROUTES with IPv4 addresses only.
# OCSERV_NO_ROUTES6: A version of OCSERV_NO_ROUTES with IPv6 addresses only.
#
# OCSERV_DNS:        A space-separated list of DNS servers the user has access to.
# OCSERV_DNS4:       A version of OCSERV_DNS with IPv4 addresses only.
# OCSERV_DNS6:       A version of OCSERV_DNS with IPv6 addresses only.
#
# OCSERV_DENY_PORTS: A space-separated list of port types and ports that the user
#                    should be denied access to. An example of the format is:
#                    "tcp 443 udp 312 sctp 999 icmp all esp all icmpv6 all"
#
# OCSERV_ALLOW_PORTS: A space-separated list of port types and ports that the user
#                     should be granted access to. If set the user must be denied access
#                     to any other ports. An example of the format is:
#                     "tcp 443 udp 312 sctp 999 icmp all esp all icmpv6 all"

# Search the sbin directories first; the tools used below are looked up
# through PATH.
PATH=/sbin:/usr/sbin:$PATH

# Locate the netmask helper, preferring ipcalc-ng over ipcalc.
# `command -v` is the POSIX way to resolve a command name; `which` is an
# optional external program that is not installed on every minimal system.
IPCALC=$(command -v ipcalc-ng 2>/dev/null)
if test -z "${IPCALC}"; then
	IPCALC=$(command -v ipcalc 2>/dev/null)
fi

# Without either tool the script cannot continue: report and give up.
if test -z "${IPCALC}"; then
	logger -t ocserv-fw "ipcalc or ipcalc-ng is required but not found"
	exit 1
fi

# Per-device table name. Every character of ${DEVICE} outside [a-zA-Z0-9_]
# is replaced by '_' because nft table names cannot contain hyphens, dots,
# or other non-identifier characters.
_device_ident=$(echo "${DEVICE}" | sed 's/[^a-zA-Z0-9_]/_/g')
TABLE="ocserv_${_device_ident}"

# "--removeall": delete every inet table whose name starts with "ocserv_",
# ignoring individual failures, then stop.
case "$1" in
	--removeall)
		_ocserv_tables=$(nft list tables inet 2>/dev/null | sed -n 's/^table inet \(ocserv_[^ ]*\).*/\1/p')
		# Unquoted on purpose: one table name per word.
		for _table_name in $_ocserv_tables; do
			nft delete table inet "${_table_name}" 2>/dev/null || true
		done
		exit 0
		;;
esac

# Run the script named by OCSERV_NEXT_SCRIPT with /bin/sh, if one is set.
#
# Globals:  OCSERV_NEXT_SCRIPT (read, then unset)
# Returns:  0 when no script is configured, otherwise the exit status of
#           the invoked script.
execute_next_script() {
	# Nothing chained: succeed without doing anything.
	if test -z "${OCSERV_NEXT_SCRIPT}"; then
		return 0
	fi

	_next_script="${OCSERV_NEXT_SCRIPT}"
	# The variable is removed before the child starts, so the child does not
	# inherit it (presumably to keep the next script from chaining again).
	unset OCSERV_NEXT_SCRIPT
	/bin/sh "${_next_script}"
}

# Dispatch on the reason the script was invoked for.
case "${REASON}" in
	disconnect)
		# Tear down this device's table (a missing table is not an error),
		# hand over to the next script and finish successfully.
		nft delete table inet "${TABLE}" 2>/dev/null || true
		execute_next_script
		exit 0
		;;
	connect)
		# Handled by the remainder of the script.
		;;
	*)
		logger -t ocserv-fw "unknown reason ${REASON}"
		exit 1
		;;
esac

# From here on, abort on the first command that fails.
set -e

# Print the route-restriction rules for ${DEVICE} to stdout, for inclusion
# in the nft chain currently being written.
#
# Globals read: DEVICE, OCSERV_RESTRICT_TO_ROUTES, OCSERV_ROUTES,
#               OCSERV_ROUTES4, OCSERV_ROUTES6,
#               OCSERV_NO_ROUTES4, OCSERV_NO_ROUTES6
#
# When OCSERV_RESTRICT_TO_ROUTES is '1' the rules are, in order:
#   1. reject destinations listed in OCSERV_NO_ROUTES4 / OCSERV_NO_ROUTES6;
#   2. if OCSERV_ROUTES is non-empty, accept destinations listed in
#      OCSERV_ROUTES4 / OCSERV_ROUTES6 and reject everything else;
#      otherwise accept everything else.
# For any other value a single accept-all rule is printed.
emit_route_rules() {
	# Unrestricted user: one catch-all accept and we are done.
	if test "${OCSERV_RESTRICT_TO_ROUTES}" != "1"; then
		printf '    iif "%s" accept\n' "${DEVICE}"
		return 0
	fi

	# Denied destinations are emitted first. The route lists are expanded
	# unquoted on purpose: each whitespace-separated word is one route.
	if test -n "${OCSERV_NO_ROUTES4}"; then
		_deny4=$(routes_to_nft $OCSERV_NO_ROUTES4)
		printf '    iif "%s" ip daddr { %s } reject\n' "${DEVICE}" "${_deny4}"
	fi
	if test -n "${OCSERV_NO_ROUTES6}"; then
		_deny6=$(routes_to_nft $OCSERV_NO_ROUTES6)
		printf '    iif "%s" ip6 daddr { %s } reject\n' "${DEVICE}" "${_deny6}"
	fi

	# No explicit allow-list: whatever was not rejected above is accepted.
	if test -z "$OCSERV_ROUTES"; then
		printf '    iif "%s" accept\n' "${DEVICE}"
		return 0
	fi

	# Explicit allow-list: accept the listed destinations, reject the rest.
	if test -n "$OCSERV_ROUTES4"; then
		_allow4=$(routes_to_nft $OCSERV_ROUTES4)
		printf '    iif "%s" ip daddr { %s } accept\n' "${DEVICE}" "${_allow4}"
	fi
	if test -n "$OCSERV_ROUTES6"; then
		_allow6=$(routes_to_nft $OCSERV_ROUTES6)
		printf '    iif "%s" ip6 daddr { %s } accept\n' "${DEVICE}" "${_allow6}"
	fi
	printf '    iif "%s" reject\n' "${DEVICE}"
}

# Convert a route that may use a dotted-decimal subnet mask
# (e.g. 10.0.0.0/255.0.0.0, as ocserv normalises IPv4 routes) to
# CIDR prefix-length notation (10.0.0.0/8) required by nftables.
# IPv6 routes and routes already in CIDR notation are passed through.
#
# Arguments: $1 - the route to normalise
# Globals:   IPCALC (read) - command invoked as "${IPCALC} -p <route>",
#            expected to print a single "PREFIX=<n>" line
# Outputs:   the normalised route on stdout
# Returns:   0 on success; 1 when no numeric prefix could be obtained, in
#            which case the route is printed unchanged so that the problem
#            stays visible to whatever consumes the output.
normalize_route() {
	case "$1" in
		*/*.*.*.*)
			# Treat the helper's output as data instead of eval-ing it:
			# eval would execute any extra text and, when the helper
			# fails, would silently reuse a PREFIX inherited from the
			# environment.
			_nr_out=$("${IPCALC}" -p "$1") || _nr_out=""
			case "${_nr_out}" in
				PREFIX=*) _nr_prefix=${_nr_out#PREFIX=} ;;
				*)        _nr_prefix="" ;;
			esac
			# Tolerate a value wrapped in double quotes.
			_nr_prefix=${_nr_prefix#\"}
			_nr_prefix=${_nr_prefix%\"}

			# Only a non-empty, all-digit value is a usable prefix length.
			case "${_nr_prefix}" in
				''|*[!0-9]*)
					logger -t ocserv-fw "cannot convert netmask of route $1 to a prefix length" || true
					printf '%s\n' "$1"
					return 1
					;;
			esac
			printf '%s/%s\n' "${1%%/*}" "${_nr_prefix}"
			;;
		*)
			printf '%s\n' "$1"
			;;
	esac
}

# Join a list of routes into one comma-separated line suitable for an
# nftables inline set, passing every route through normalize_route.
#
# Arguments: $@ - routes; the arguments are re-split on whitespace, so a
#                 single argument holding several routes works as well
# Outputs:   the joined list followed by a newline (an empty line when no
#            routes are given)
routes_to_nft() {
	# Deliberately unquoted: re-split the arguments into one word per route.
	set -- $@
	_delim=""
	while test $# -gt 0; do
		printf '%s%s' "${_delim}" "$(normalize_route "$1")"
		_delim=","
		shift
	done
	printf '\n'
}

# Emit one rule per "<proto> <port>" pair into the current nft chain context.
#
# Arguments: $1    - verdict appended to every rule ("reject", "jump <chain>")
#            $2... - alternating protocol / port words, in the format of
#                    OCSERV_DENY_PORTS / OCSERV_ALLOW_PORTS
# icmp, icmpv6 and esp are matched by protocol only; their port word is
# ignored. A trailing protocol word without a port word is ignored.
emit_port_rules() {
	_verdict=$1
	shift
	while test $# -gt 1; do
		case "$1" in
			icmp)   printf '    iif "%s" ip protocol icmp %s\n'   "${DEVICE}" "${_verdict}" ;;
			icmpv6) printf '    iif "%s" ip6 nexthdr icmpv6 %s\n' "${DEVICE}" "${_verdict}" ;;
			esp)    printf '    iif "%s" meta l4proto esp %s\n'   "${DEVICE}" "${_verdict}" ;;
			*)      printf '    iif "%s" meta l4proto %s th dport %s %s\n' \
				"${DEVICE}" "$1" "$2" "${_verdict}" ;;
		esac
		shift 2
	done
}

# Decide the port policy once so that every later decision (which port rules
# to emit, where the route rules go, whether the extra chain exists) agrees.
# A deny-list takes precedence when both lists are set; previously that
# combination produced a table without any route restriction.
if test -n "${OCSERV_DENY_PORTS}"; then
	port_mode="deny"
	if test -n "${OCSERV_ALLOW_PORTS}"; then
		logger -t ocserv-fw "both OCSERV_DENY_PORTS and OCSERV_ALLOW_PORTS are set; ignoring OCSERV_ALLOW_PORTS" || true
	fi
elif test -n "${OCSERV_ALLOW_PORTS}"; then
	port_mode="allow"
else
	port_mode="none"
fi

# Remove any leftover table for this device (must be outside the atomic block)
nft delete table inet "${TABLE}" 2>/dev/null || true

# Build and apply the complete table definition atomically
{
	printf 'table inet %s {\n  chain ocserv_fwd {\n' "${TABLE}"
	# priority filter-10 ensures ocserv rules evaluate before firewalld (priority 0)
	printf '    type filter hook forward priority filter - 10; policy accept;\n'
	printf '    oif "%s" ct state established,related accept\n' "${DEVICE}"

	# DNS: collect addresses into a single inline set per address family.
	# These rules come before the port and route rules, so the listed
	# servers are accepted on port 53 before any of those are evaluated.
	if test -n "${OCSERV_DNS4}"; then
		dns4=$(echo $OCSERV_DNS4 | sed 's/ /,/g')
		printf '    iif "%s" ip daddr { %s } udp dport 53 ct state new accept\n' \
			"${DEVICE}" "${dns4}"
		printf '    iif "%s" ip daddr { %s } tcp dport 53 ct state new,established accept\n' \
			"${DEVICE}" "${dns4}"
	fi
	if test -n "${OCSERV_DNS6}"; then
		dns6=$(echo $OCSERV_DNS6 | sed 's/ /,/g')
		printf '    iif "%s" ip6 daddr { %s } udp dport 53 ct state new accept\n' \
			"${DEVICE}" "${dns6}"
		printf '    iif "%s" ip6 daddr { %s } tcp dport 53 ct state new,established accept\n' \
			"${DEVICE}" "${dns6}"
	fi

	# Port restrictions followed by (or combined with) route restrictions.
	# The port lists are expanded unquoted on purpose: one word per token.
	case "${port_mode}" in
		deny)
			# Denied ports are rejected directly; route restriction follows
			# for the remaining traffic.
			emit_port_rules "reject" ${OCSERV_DENY_PORTS}
			emit_route_rules
			;;
		allow)
			# Allowed ports jump to the route-restriction chain; everything
			# else is rejected. This ensures route restriction (which
			# restrict-user-to-ports implies) is evaluated for each allowed
			# port.
			emit_port_rules "jump ocserv_rt_${TABLE}" ${OCSERV_ALLOW_PORTS}
			printf '    iif "%s" reject\n' "${DEVICE}"
			;;
		*)
			# No port restriction: only the route rules apply.
			emit_route_rules
			;;
	esac

	printf '  }\n'

	# Named chain jumped to by the allow-list rules; applies route restriction
	# so that both port and route policies are enforced simultaneously.
	if test "${port_mode}" = "allow"; then
		printf '  chain ocserv_rt_%s {\n' "${TABLE}"
		emit_route_rules
		printf '  }\n'
	fi

	printf '}\n'
} | nft -f -

execute_next_script

exit 0
