#!/usr/bin/env python3
#
# Copyright (c) 2003-2007 Andrea Luzzardi <scox@sig11.org>
#
# This file is part of the pam_usb project. pam_usb is free software;
# you can redistribute it and/or modify it under the terms of the GNU General
# Public License version 2, as published by the Free Software Foundation.
#
# pam_usb is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
# details.
#
# You should have received a copy of the GNU General Public License along with
# this program; if not, write to the Free Software Foundation, Inc., 51 Franklin
# Street, Fifth Floor, Boston, MA 02110-1301 USA.

import fcntl
import getopt
import gi
import os
import pwd
import re
import signal
import subprocess
import sys
import syslog
import threading
import datetime
import dbus
import time
from dbus.mainloop.glib import DBusGMainLoop

DBusGMainLoop(set_as_default=True)

gi.require_version('UDisks', '2.0')

from gi.repository import GLib, UDisks

import xml.etree.ElementTree as et

class HotPlugDevice:
	"""Track whether the drive with a given serial number is attached.

	The watcher scans the objects UDisks already knows about, then follows
	the object manager's 'object-added' / 'object-removed' signals. Every
	registered callback is invoked as ``callback(event, name)`` with event
	being 'added' or 'removed'; callbacks are only fired once run() has
	finished the initial scan.
	"""

	def __init__(self, serial, name):
		"""Remember the serial to look for and the name reported to callbacks."""
		self.__udi = None  # drive object currently matched, None while absent
		self.__serial = serial
		self.__name = name
		self.__callbacks = []
		self.__running = False

	def run(self):
		"""Scan existing devices, subscribe to hotplug signals and loop forever."""
		self.__scanDevices()
		self.__registerSignals()
		# Set only after the initial scan so a device that is already
		# plugged in does not trigger the callbacks.
		self.__running = True
		# Announce before entering the loop: MainLoop.run() blocks until the
		# loop is quit, so anything placed after it would not be printed
		# while the agent is running.
		print('signals registered')
		GLib.MainLoop().run()

	def addCallback(self, callback):
		"""Register ``callback(event, name)`` for add/remove notifications."""
		self.__callbacks.append(callback)

	def isDeviceConnected(self):
		"""Return True while the watched drive is known to be attached."""
		return self.__udi is not None

	def getWatchedDeviceName(self):
		"""Return the name this watcher was created with."""
		return self.__name

	def __driveForObject(self, udisksObject):
		"""Return the drive behind a UDisks object's block device, if any."""
		block = udisksObject.get_block()
		if not block:
			return None
		return udisks.get_drive_for_block(block)

	def __notify(self, event):
		"""Invoke every registered callback with the event and device name."""
		for callback in self.__callbacks:
			callback(event, self.__name)

	def __scanDevices(self):
		for udisksObject in udisksObjectManager.get_objects():
			drive = self.__driveForObject(udisksObject)
			if drive:
				self.__deviceAdded(drive)

	def __registerSignals(self):
		udisksObjectManager.connect('object-added', self.__objectAdded)
		udisksObjectManager.connect('object-removed', self.__objectRemoved)

	def __objectAdded(self, _, udi):
		drive = self.__driveForObject(udi)
		if drive:
			self.__deviceAdded(drive)

	def __objectRemoved(self, _, udi):
		drive = self.__driveForObject(udi)
		if drive:
			self.__deviceRemoved(drive)

	def __deviceAdded(self, udi):
		# Ignore further drives once one has been matched, and drives whose
		# serial is not the one being watched.
		if self.__udi is not None:
			return
		if udi.get_property('serial') != self.__serial:
			return
		self.__udi = udi
		if self.__running:
			self.__notify('added')

	def __deviceRemoved(self, udi):
		# Only the drive that was matched earlier counts as a removal.
		if self.__udi is None:
			return
		if self.__udi != udi:
			return
		self.__udi = None
		if self.__running:
			self.__notify('removed')

class Log:
	"""Small syslog front end used for every agent message."""

	def __init__(self):
		# Messages are tagged 'pamusb-agent' and sent to the auth facility;
		# LOG_PID and LOG_PERROR are the logging options passed to openlog.
		log_options = syslog.LOG_PID | syslog.LOG_PERROR
		syslog.openlog('pamusb-agent', log_options, syslog.LOG_AUTH)

	def __logMessage(self, priority, message):
		"""Send one message to syslog at the given priority."""
		syslog.syslog(priority, message)

	def info(self, message):
		"""Log an informational message (sent at LOG_NOTICE priority)."""
		self.__logMessage(priority=syslog.LOG_NOTICE, message=message)

	def error(self, message):
		"""Log an error message (sent at LOG_ERR priority)."""
		self.__logMessage(priority=syslog.LOG_ERR, message=message)

def usage():
	"""Print the command line synopsis and terminate with exit status 1."""
	script_name = os.path.basename(__file__)
	synopsis = 'Usage: %s [--help] [--config=path] [--daemon] [--check=path]' % script_name
	print(synopsis)
	raise SystemExit(1)

def runAs(uid, gid):
	"""Build a ``preexec_fn`` hook that drops a child process to uid/gid.

	Besides the primary group and user id, the hook also replaces the
	supplementary group list; without that step the child would keep the
	groups of the process that spawned it.

	Args:
		uid: numeric user id the child should run as.
		gid: numeric primary group id the child should run as.

	Returns:
		A zero-argument callable suitable for ``subprocess``'s preexec_fn.
	"""
	# Resolve the group list here, in the caller, so the hook itself only
	# performs plain system calls.
	try:
		groups = os.getgrouplist(pwd.getpwuid(uid).pw_name, gid)
	except KeyError:
		# No passwd entry for this uid: fall back to the primary group only.
		groups = [gid]

	def set_id():
		# Changing supplementary groups needs root; when not running as root
		# there are no elevated groups to shed, so the step is skipped.
		if os.geteuid() == 0:
			os.setgroups(groups)
		# Group first: after setuid() the process may no longer be allowed
		# to change its group.
		os.setgid(gid)
		os.setuid(uid)
	return set_id

import getopt

# Command line handling. '-c' is the short form of --config only; --check
# has no short form, so giving '-c' can no longer overwrite the check path.
try:
	opts, args = getopt.getopt(sys.argv[1:], "hc:d", ["help", "config=", "daemon", "check="])
except getopt.GetoptError:
	usage()

options = {
	'configFile': '/etc/security/pam_usb.conf',
	'daemon': False,
	'check': '/usr/bin/pamusb-check'
}

# Positional arguments are not supported.
if len(args) != 0:
	usage()

for o, a in opts:
	if o in ('-h', '--help'):
		usage()
	elif o in ('-c', '--config'):
		options['configFile'] = a
	elif o in ('-d', '--daemon'):
		options['daemon'] = True
	elif o == '--check':
		options['check'] = a


if not os.path.exists(options['check']):
	print('%s not found.' % options['check'])
	print("You might specify manually pamusb-check's location using --check.")
	usage()

logger = Log()

# Load the configuration; report an unreadable or malformed file through the
# logger instead of dying with a traceback.
try:
	doc = et.parse(options['configFile'])
except (OSError, et.ParseError) as configError:
	logger.error('Unable to load configuration file "%s": %s' % (options['configFile'], configError))
	sys.exit(1)
users = doc.findall('users/user')

def login1ManagerDBusIface():
	"""Return the org.freedesktop.login1.Manager interface on the system bus."""
	login1_object = dbus.SystemBus().get_object(
		'org.freedesktop.login1',
		'/org/freedesktop/login1'
	)
	return dbus.Interface(login1_object, 'org.freedesktop.login1.Manager')

def userDeviceThread(user):
	"""Watch every device configured for one user and lock/unlock on hotplug.

	Reads the user's <agent> entries (commands and environment per event)
	and <device> entries from the parsed configuration, starts one
	HotPlugDevice watcher thread per matching device and subscribes to the
	login1 PrepareForSleep / PrepareForShutdown signals.

	Args:
		user: the <user> element taken from the configuration file.

	Returns:
		1 when none of the user's devices is defined in the configuration,
		otherwise None once the watcher threads have been started.
	"""
	# Imported locally so this function does not rely on the module header.
	import shlex

	userName = user.get('id')
	account = pwd.getpwnam(userName)
	uid = account.pw_uid
	gid = account.pw_gid
	# NOTE(review): this rebinds os.environ for the whole process, not only
	# for this thread; the commands run below receive an explicit env=.
	os.environ = None

	events = {
		'lock': [],
		'unlock': []
	}

	for hotplug in user.findall('agent'):
		eventName = hotplug.get('event')
		if eventName not in events:
			# Only 'lock' and 'unlock' are handled; anything else used to
			# abort the whole thread with a KeyError.
			logger.error('Ignoring agent with unknown event "%s" for user "%s".' % (eventName, userName))
			continue

		henvs = {}
		hcmds = []

		for hcmd in hotplug.findall('cmd'):
			if hcmd.text is not None:
				hcmds.append(hcmd.text)
			else:
				logger.error('Ignoring empty command for user "%s".' % userName)

		for henv in hotplug.findall('env'):
			if henv.text is None:
				logger.error('Ignoring empty environment variable for user "%s".' % userName)
				continue

			# Split at the first '='; an entry without '=' is invalid rather
			# than being taken as NAME=NAME.
			henv_var, separator, henv_arg = henv.text.partition('=')
			if separator and henv_var != '' and henv_arg != '':
				henvs[henv_var] = henv_arg
			else:
				logger.error('Ignoring invalid command environment variable for user "%s".' % userName)

		events[eventName].append(
			{
				'env': henvs,
				'cmd': hcmds
			}
		)

	# Keep only the globally defined devices that this user references.
	devices_for_user = [device.text for device in user.findall("device")]
	to_watch = [
		{"name": device.get('id'), "serial": device.findtext('serial')}
		for device in doc.findall("devices/device")
		if device.get('id') in devices_for_user
	]

	if not to_watch:
		logger.error('Device(s) not found in configuration file.')
		return 1

	resumeTimestamp = datetime.datetime.min
	hpDevs = []

	def runEventCommands(eventName):
		"""Run every command configured for 'lock' or 'unlock' as the user."""
		for entry in events[eventName]:
			if not entry['cmd']:
				logger.info('No commands defined for %s!' % eventName)
				continue

			for cmd in entry['cmd']:
				logger.info('Running "%s" (as %d:%d)' % (cmd, uid, gid))
				process = subprocess.run(['sh', '-c', cmd], env=entry['env'], preexec_fn=runAs(uid, gid), capture_output=True)
				logger.info('Process exit code: %d' % (process.returncode))
				# errors='replace' keeps non-UTF-8 output from raising here.
				logger.info('Process stdout: %s' % (process.stdout.decode(errors='replace')))
				logger.info('Process stderr: %s' % (process.stderr.decode(errors='replace')))

	def authChangeCallback(event, deviceName):
		"""React to a device being 'added' or 'removed' for this user."""
		for otherDevice in hpDevs:
			if otherDevice.getWatchedDeviceName() != deviceName and otherDevice.isDeviceConnected():
				logger.info('Device "%s" removed or plugged but another one is connected anyway, ignoring' % deviceName)
				return

		if event == 'removed':
			# Removals seen less than five seconds after a resume (or at any
			# time while suspended, when the timestamp lies in the future)
			# are ignored.
			sinceResume = datetime.datetime.now() - resumeTimestamp
			if sinceResume.total_seconds() < 5.0:
				logger.info('Device is removed during suspend, ignoring')
				return

			logger.info('Device "%s" has been removed, locking down user "%s"...' % (deviceName, userName))
			runEventCommands('lock')
			logger.info('Locked.')
			return

		logger.info('Device "%s" has been inserted. Performing verification...' % deviceName)

		# Quote each part so paths or names containing spaces or shell
		# metacharacters reach pamusb-check unchanged.
		cmdLine = "%s --debug --config=%s --service=pamusb-agent %s" % (
			shlex.quote(options['check']),
			shlex.quote(options['configFile']),
			shlex.quote(userName)
		)
		logger.info('Executing "%s"' % cmdLine)
		if os.system(cmdLine) != 0:
			logger.info('Authentication failed for device %s. Keeping user "%s" locked down.' % (deviceName, userName))
			return

		logger.info('Authentication succeeded. Unlocking user "%s"...' % userName)
		runEventCommands('unlock')
		logger.info('Unlocked.')

	def onSuspendOrResume(start, member=None):
		"""Handle a login1 sleep/shutdown signal; start is true beforehand."""
		nonlocal resumeTimestamp

		if start:
			logger.info('Suspending user "%s"' % (userName))
			# A timestamp in the far future keeps every removal seen while
			# suspended inside the grace window checked on removal.
			resumeTimestamp = datetime.datetime.max
			return

		logger.info('Resuming user "%s"' % (userName))
		for hpDev in hpDevs:
			if hpDev.isDeviceConnected():
				# Use the public accessor: the private attribute is name-
				# mangled and cannot be reached as hpDev.__name from here.
				watchedName = hpDev.getWatchedDeviceName()
				logger.info('Device %s is connected for user "%s", unlocking' % (watchedName, userName))
				authChangeCallback('added', watchedName)

		resumeTimestamp = datetime.datetime.now()

	login1Interface = login1ManagerDBusIface()
	for signalName in ['PrepareForSleep', 'PrepareForShutdown']:
		login1Interface.connect_to_signal(signalName, onSuspendOrResume, member_keyword='member')

	for watch_this in to_watch:
		logger.info('Watching device "%s" for user "%s"' % (watch_this.get('name'), userName))

		hpDev = HotPlugDevice(watch_this.get('serial'), watch_this.get('name'))
		hpDev.addCallback(authChangeCallback)

		# Register the watcher before its thread starts so callbacks fired
		# from that thread already see it in hpDevs.
		hpDevs.append(hpDev)
		threading.Thread(target=hpDev.run).start()

# Shared UDisks client and its object manager; HotPlugDevice reads both as
# module-level globals.
udisks = UDisks.Client.new_sync()
udisksObjectManager = udisks.get_object_manager()

# sysUsers collects the account names read from /etc/passwd; validUsers
# collects the configured <user> elements whose id matches one of them.
sysUsers = []
validUsers = []

def processCheck():
	"""Exit unless this is the only running instance and it runs as root.

	An exclusive, non-blocking flock on this script's own file is used as
	the single-instance guard. The file object is stored in the global
	``filelock`` so the lock stays held while the process is alive.
	"""
	global filelock
	filelock = open(os.path.realpath(__file__), 'r')

	try:
		fcntl.flock(filelock, fcntl.LOCK_EX | fcntl.LOCK_NB)
	except OSError:
		# flock reports a lock held elsewhere through OSError; catching only
		# that keeps unrelated failures from being reported as "already
		# running".
		logger.error('Process is already running.')
		sys.exit(1)

	if os.getuid() != 0:
		logger.error('Process must be run as root.')
		sys.exit(1)

processCheck()

# Collect the local account names: the first ':'-separated field of each
# /etc/passwd line.
try:
	with open('/etc/passwd', 'r') as f:
		for line in f:
			sysUsers.append(line.rstrip('\n').split(':', 1)[0])
except (OSError, UnicodeDecodeError):
	logger.error('Couldn\'t read system user names from "/etc/passwd". Process can\'t continue.')
	sys.exit(1)

logger.info('pamusb-agent up and running.')

# Only users that exist both in the configuration and on the system are
# watched; a set makes each membership test a single lookup.
knownSysUsers = set(sysUsers)
for userObj in users:
	userId = userObj.get('id')

	if userId in knownSysUsers:
		validUsers.append(userObj)
	else:
		logger.error('User "%s" from the configuration file is not a system user, ignoring.' % userId)

for user in validUsers:
	threading.Thread(
		target=userDeviceThread,
		args=(user,)
	).start()

# NOTE(review): this fork happens after the watcher threads were started, and
# threads are not carried over into a forked child — confirm that --daemon
# behaves as intended.
if options['daemon'] and os.fork():
	sys.exit(0)

def sig_handler(sig, frame):
	"""Signal handler: log that the agent is stopping and exit with status 0."""
	logger.info('Stopping agent.')
	raise SystemExit(0)

# Names of the signals that are routed to sig_handler.
sys_signals = ['SIGINT', 'SIGTERM', 'SIGTSTP', 'SIGTTIN', 'SIGTTOU']

for signal_name in sys_signals:
	signal_number = getattr(signal, signal_name)
	signal.signal(signal_number, sig_handler)
