/*
 *      NiceShaper - Dynamic Traffic Management
 *
 *      Copyright 2004 Mariusz Jedwabny <mariusz@jedwabny.net>
 *
 *      This file is subject to the terms and conditions of the GNU General Public
 *      License.  See the file COPYING in the main directory of this archive for
 *      more details.
 */

#include "filter.h"

#include <cerrno>
#include <cstring>
#include <arpa/inet.h>
#include <stdlib.h>
#include <limits.h>

#include <string>
#include <vector>
#include <iostream>

#include "main.h"
#include "aux.h"
#include "log.h"
#include "sys.h"
#include "ifaces.h"
#include "tests.h"

using std::string;
using std::vector;
using namespace aux;

/*
 * Builds a filter from one configuration line.
 *
 * buf              - the configuration line; everything after its first word
 *                    becomes the match specification (Match)
 * dev              - device the filter is attached to
 * ipt_my_chain     - iptables chain this filter's rules are appended to
 * ipt_section_hook - hook of the owning section ("PREROUTING"/"POSTROUTING")
 */
TcFilter::TcFilter(string buf, string dev, string ipt_my_chain, string ipt_section_hook)
{
    FilterId = 0;
    FlowId = 0;
    HandleFw = 0;
    Chains = 0;
    ImqAutoRedirect = true;
    DestMark = false;
    FromLocal = false;
    ToLocal = false;
    UseTcFilter = true;
    UseForFw = false;
    IptApplied = false;
    IptTarget = IPT_ACCEPT;
    Dev = dev;
    DevId = ifaces->index(Dev);
    TcFilterType = ifaces->tcFilterType(Dev);
    IptMyChain = ipt_my_chain;
    IptSectionHook = ipt_section_hook;

    // Match = words 2..last of buf joined by single spaces.  awk(s, n) is
    // presumably "n-th word, empty past the end"; each word is fetched once.
    Match = "";
    string word;
    for (unsigned int n = 2; (word = awk(buf, n)).size(); n++) {
        if ( n >= 3 ) Match += " ";
        Match += word;
    }

    // Every parameter is looked up once, and its value is read from the same
    // string (buf) in which its presence is tested.
    const string use_for_fw = value_of_param(buf, "use-for-fw");
    const string set_mark = value_of_param(buf, "set-mark");
    const string filter_id = value_of_param(buf, "filterid");
    const string to_local = value_of_param(buf, "to-local");
    const string from_local = value_of_param(buf, "from-local");

    if (use_for_fw == "yes") UseForFw = true;
    if (set_mark.size()) HandleFw = read_fw_mark(set_mark);
    if (filter_id.size()) FilterId = str_to_uint(filter_id);

    // to-local and from-local are mutually exclusive; to-local wins.
    if (to_local.size() && test->validIp(to_local)) ToLocal = true;
    else if (from_local.size() && test->validIp(from_local)) FromLocal = true;

    memset(&TcU32Selector, 0, sizeof(TcU32Selector));
}

/*
 * Removes the OUTPUT/INPUT jump rule that prepareIptRules() adds for
 * from-local/to-local filters.  IptApplied is set once prepareIptRules()
 * has run; before that there is nothing to remove.
 */
TcFilter::~TcFilter()
{
    if (!IptApplied)
        return;

    const string jump = " " + IptMatch + " -j " + IptMyChain;

    // from-local traffic in a PREROUTING section is hooked from OUTPUT
    if (IptSectionHook == "PREROUTING" && FromLocal)
        sys->ipt("iptables -t mangle -D OUTPUT" + jump);

    // to-local traffic in a POSTROUTING section is hooked from INPUT
    if (IptSectionHook == "POSTROUTING" && ToLocal)
        sys->ipt("iptables -t mangle -D INPUT" + jump);
}

/*
 * Applies one class-level configuration line (words: option param value)
 * to this filter.
 *
 * Returns 1 when the line was accepted (or not relevant to filters), and -1
 * after logging error 11 for an invalid param/value of a known option.
 */
int TcFilter::store(string buf)
{
    string option, param, value;

    option = awk( buf, 1);
    param = awk( buf, 2);
    value = awk( buf, 3);

    if ( option == "class" ) 
    {
        // nothing to store for the filter
    }
    else if ( option == "classid" ) {
        FlowId = str_to_uint( param);
    }
    else if ( option == "type" ) {
        if ( param == "standard-class" ) {
            UseTcFilter = true; 
        }   
        else if ( param == "wrapper" ) {
            UseTcFilter = true;
        }   
        else if ( param == "do-not-shape" ) { 
            UseTcFilter = true; 
            if (ifaces->isDNShapeMethodSafe(Dev)) FlowId = ifaces->htbDNWrapperId();
            else FlowId = 0;
        } 
        else if ( param == "virtual" ) { 
            // virtual classes get no tc filter and no iptables target
            UseTcFilter = false;
            IptTarget = IPT_FALSE; 
        } 
        else { log.error (11, buf); return -1; }
    }
    else if ( option == "imq" ) {
        if ( param == "autoredirect" ) {
            if ( value == "yes" ) ImqAutoRedirect = true;
            else if ( value == "no" ) ImqAutoRedirect = false;
            else if ( value == "true" ) {
                // accepted, but warned about (warning 10)
                ImqAutoRedirect = true;
                log.warning( 10, buf );
            }
            else if (( value == "false" ) || ( value == "none" )) {
                // accepted, but warned about (warning 9)
                ImqAutoRedirect = false;
                log.warning( 9, buf );
            }
            else {
                // Reject like every other invalid value in this function,
                // instead of logging and carrying on with the old setting.
                log.error( 11, buf );
                return -1;
            }
        }
        else {
            log.error( 11, buf );
            return -1;
        }
    }
    else if ( option == "iptables" )
    {
        if ( param == "target" ) 
        {
            if ( value == "accept" ) IptTarget = IPT_ACCEPT;
            else if ( value == "return" ) IptTarget = IPT_RETURN;
            else if ( value == "drop" ) IptTarget = IPT_DROP;
            else if ( value == "empty" ) IptTarget = IPT_FALSE;
            else { 
                log.error( 11, buf ); 
                return -1; 
            }
        }
        else {
            log.error( 11, buf );
            return -1;
        }
    }

    return 1;
}

/*
 * Translates Match into iptables match arguments (stored in IptMatch).
 *
 * Returns 1 on success and -1 when gen_ipt_filter() fails.  Also returns -1
 * (error 36) when gen_ipt_filter() reports via need_mark that the match needs
 * a packet mark but the device does not use fw-type tc filters.
 */
int TcFilter::prepareIptFilter()
{
    // Initialized so that a gen_ipt_filter() path which does not assign it
    // cannot leave an indeterminate value to be read below.
    bool need_mark = false;

    if (gen_ipt_filter(Match, need_mark, IptMatch) == -1) return -1;

    if ((need_mark) && (TcFilterType != FW)) {
        log.error(36, Match);
        return -1;
    }

    return 1;
}

/*
 * Builds this filter's iptables (mangle table) rules and either applies them
 * immediately (apply == true) or appends them, without the
 * "iptables -t mangle" prefix, to ipt_rules.
 *
 *  IptRule1 - jump from OUTPUT (from-local, PREROUTING section) or INPUT
 *             (to-local, POSTROUTING section) into IptMyChain
 *  IptRule2 - MARK rule, only for fw-type devices
 *  IptRule3 - IMQ redirect, for imq devices with autoredirect enabled
 *  IptRule4 - the rule with the configured target (ACCEPT/RETURN/DROP/none)
 *
 * Chains records how many of these rules land in IptMyChain; check() relies
 * on that count.  Always returns 1.
 */
int TcFilter::prepareIptRules(bool apply, std::vector <std::string> &ipt_rules)
{
    // From/To local rule
    if (FromLocal && (IptSectionHook == "PREROUTING")) {
        IptRule1 = " -A OUTPUT " + IptMatch + " -j " + IptMyChain;
    }
    else if (ToLocal && (IptSectionHook == "POSTROUTING")) {
        IptRule1 = " -A INPUT " + IptMatch + " -j " + IptMyChain;
    }
    else IptRule1 = "";

    // Marking rule
    if (TcFilterType == FW) {
        IptRule2 = " -A " + IptMyChain + " " + IptMatch + " -j MARK --set-mark " + int_to_str(HandleFw);
    }
    else IptRule2 = "";

    // IMQ redirect rule: --todev gets the text following "imq" in the device
    // name (e.g. "imq0" -> "0"), taken from where "imq" was actually found.
    const string::size_type imq_pos = Dev.find("imq");
    if ((imq_pos != string::npos) && ImqAutoRedirect) {
        IptRule3 = " -A " + IptMyChain + " " + IptMatch + " -j IMQ --todev " + Dev.substr(imq_pos + 3);
    }
    else IptRule3 = "";

    // Common rule
    if ( IptTarget == IPT_DROP ) IptRule4 = " -A " + IptMyChain + " " + IptMatch + " -j DROP";
    else if ( IptTarget == IPT_RETURN ) IptRule4 = " -A " + IptMyChain + " " + IptMatch + " -j RETURN";
    else if ( IptTarget == IPT_FALSE ) IptRule4 = " -A " + IptMyChain + " " + IptMatch;
    else IptRule4 = " -A " + IptMyChain + " " + IptMatch + " -j ACCEPT";

    // IptRule1 goes to a built-in chain, so it is not counted.
    Chains = (IptRule2.size() ? 1 : 0) +(IptRule3.size() ? 1 : 0) + (IptRule4.size() ? 1 : 0);

    // Emit the non-empty rules in order.
    const string *rules[] = { &IptRule1, &IptRule2, &IptRule3, &IptRule4 };
    for (size_t i = 0; i < sizeof(rules) / sizeof(rules[0]); i++) {
        if (rules[i]->empty()) continue;
        if (apply) sys->ipt("iptables -t mangle " + *rules[i]);
        else ipt_rules.push_back(*rules[i]);
    }

    IptApplied = true;

    return 1;
}

int TcFilter::prepareTcFilter()
{
    string option, value;
    string addr, mask;
    unsigned int n=0;

    TcFilterAdd = "";
    TcFilterDel = "";
    if (!UseTcFilter) return 0;
    if (TcFilterType == U32) {
        if (!g_fallback_to_tc) {
            TcU32Selector.sel.flags |= TC_U32_TERMINAL;
            while ((awk( Match, ++n)).size()) {
                option = awk( Match, n );
                value = awk( Match, ++n );
                if ( option == "proto" ) {
                    if ( value == "tcp" ) { 
                        // match ip protocol 6 0xff
                        if (parseU8(&TcU32Selector.sel, 9, 0, "6", "0xff") == -1) { log.error(66, Match); return -1; }
                    }                     
                    else if ( value == "udp" ) { 
                        // match ip protocol 17 0xff
                        if (parseU8(&TcU32Selector.sel, 9, 0, "17", "0xff") == -1) { log.error(66, Match); return -1; }
                    } 
                    else if ( value == "icmp" ) { 
                        // match ip protocol 1 0xff
                        if (parseU8(&TcU32Selector.sel, 9, 0, "1", "0xff" ) == -1) { log.error(66, Match); return -1; }
                    } 
                    else { log.error(67, Match); return -1; }
                }
                else if ((option == "srcip") || (option == "from-local")) {
                    // match ip dst " + addr + "/" + mask
                    if (split_ip(value, addr, mask) == -1) { log.error(60, Match); return -1; }
                    if (parseIpAddr(&TcU32Selector.sel, 12, addr.c_str(), mask) == -1) { log.error(66, Match); return -1; } 
                }
                else if ((option == "dstip") || (option == "to-local")) {
                    // match ip dst " + addr + "/" + mask
                    if (split_ip(value, addr, mask) == -1) { log.error(60, Match); return -1; }
                    if (parseIpAddr(&TcU32Selector.sel, 16, addr.c_str(), mask) == -1) { log.error(66, Match); return -1; } 
                }
                else if (( option == "srcport" ) || ( option == "sport" )) {
                    // match ip sport " + value + " 0xffff
                    if (parseU16(&TcU32Selector.sel, 20, 0, value.c_str(), "0xffff") == -1) { log.error(66, Match); return -1; } 
                }
                else if (( option == "dstport" ) || ( option == "dport" )) {
                    // match ip dport " + value + " 0xffff
                    if (parseU16(&TcU32Selector.sel, 22, 0, value.c_str(), "0xffff") == -1) { log.error(66, Match); return -1; } 
                }
            }
        }
        else 
        {
            TcFilterAdd = "tc filter add dev " + Dev + " protocol ip parent 1:0 prio 10 handle 800::" + int_to_hex(FilterId) + " u32 ";
            TcFilterDel = "tc filter del dev " + Dev + " protocol ip parent 1:0 prio 10 handle 800::" + int_to_hex(FilterId) + " u32 ";

            while ((awk( Match, ++n)).size()) {
                option = awk( Match, n );
                value = awk( Match, ++n );
                if ( option == "proto" ) {
                    if ( value == "tcp" ) TcFilterAdd += " match ip protocol 6 0xff ";
                    else if ( value == "udp" ) TcFilterAdd += " match ip protocol 17 0xff ";
                    else if ( value == "icmp" ) TcFilterAdd += " match ip protocol 1 0xff ";
                }
                else if ((option == "srcip") || (option == "from-local")){
                    if ( split_ip(value, addr, mask) == -1 ) return -1;
                    if ( test->solidIpMask( mask )) TcFilterAdd += " match ip src " + addr + "/" + int_to_str(dot_to_bit(mask));
                    else { log.error( 37, Match ); return -1; }
                }
                else if ((option == "dstip" ) || (option == "to-local")){
                    if ( split_ip(value, addr, mask) == -1 ) return -1;
                    if ( test->solidIpMask( mask )) TcFilterAdd += " match ip dst " + addr + "/" + int_to_str(dot_to_bit(mask));
                    else { log.error( 37, Match ); return -1; }
                }
                else if (( option == "srcport" ) || ( option == "sport" )) {
                    TcFilterAdd += " match ip sport " + value + " 0xffff ";
                }
                else if (( option == "dstport" ) || ( option == "dport" )) {
                    TcFilterAdd += " match ip dport " + value + " 0xffff ";
                }
            }

            TcFilterAdd += " flowid 1:" + int_to_hex(FlowId);
        }
    } 
    else if (TcFilterType == FW) { 
        if (!HandleFw) HandleFw = FilterId;
        if (g_fallback_to_tc) {
            TcFilterAdd = "tc filter add dev " + Dev + " protocol ip parent 1:0 prio 10 handle 0x" + int_to_hex(HandleFw) + " fw flowid 1:" + int_to_hex(FlowId);
            TcFilterDel = "tc filter del dev " + Dev + " protocol ip parent 1:0 prio 10 handle 0x" + int_to_hex(HandleFw) + " fw";
        }
    }

    return 1;
}

unsigned int TcFilter::check(vector < string > &iptables_summary)
{
    register unsigned int n, m;
    static unsigned int state, value;    
    static char buf[MAX_LONG_BUF_SIZE];

    if ( iptables_summary.size() < Chains ) {
        log.error(12);
        return 0;
    }

    for ( n=0; n<(Chains-1); n++ ) {
        iptables_summary.pop_back();
    }

    strncpy ( buf, (iptables_summary.back()).c_str() , MAX_SHORT_BUF_SIZE );
    iptables_summary.pop_back();

    state=value=0;
    for ( n=0; n<MAX_SHORT_BUF_SIZE ; n++ ) {
        m=buf[n];
        if (( m>=65 && m<=90 ) || ( m>=97 && m<=122 )) { state=n; break; }
    }
    if ( !state ) { log.error(12); return 0; }

    for ( n=state; n>0; n-- ) {
        m=buf[n];
        if ( m>=48 && m<=57 ) { state=n; break; }
    }
    if (!n) { log.error( 12 ); return 0; }

    unsigned int mul=1;
    for ( n=state; n>0; n-- ) {
        m=buf[n];
        if ( m>=48 && m<=57 ) {
            m-=48;
            value+=mul*m;
            mul*=10;
        }
        else break;
    }

    return value << 3;
}

/*
 * Installs the tc filter prepared by prepareTcFilter(): via sys->tcFilter()
 * (netlink) or, in g_fallback_to_tc mode, by running TcFilterAdd.
 *
 * Returns 1 when a filter was installed (U32 device, or FW device with
 * UseForFw), 0 when this filter needs no tc filter.
 */
unsigned int TcFilter::add()
{
    if (TcFilterType == U32) {
        if (!g_fallback_to_tc) sys->tcFilter(TC_ADD, DevId, FilterId, FlowId, TcFilterType, &TcU32Selector);
        else sys->tc(TcFilterAdd);
    }
    // "else if": the U32 branch must not fall through to "return 0".
    else if ((TcFilterType == FW) && (UseForFw)) {
        if (!g_fallback_to_tc) sys->tcFilter(TC_ADD, DevId, HandleFw, FlowId, TcFilterType, NULL);
        else sys->tc(TcFilterAdd);
    }
    else return 0;

    return 1;
}

/*
 * Removes the tc filter installed by add(): via sys->tcFilter() (netlink)
 * or, in g_fallback_to_tc mode, by running TcFilterDel.
 *
 * Returns 1 when a filter was removed (U32 device, or FW device with
 * UseForFw), 0 when this filter has no tc filter.
 */
unsigned int TcFilter::del()
{
    if (TcFilterType == U32) {    
        if (!g_fallback_to_tc) sys->tcFilter(TC_DEL, DevId, FilterId, 0, TcFilterType, NULL);
        else sys->tc(TcFilterDel);
    }
    // "else if": the U32 branch must not fall through to "return 0".
    else if ((TcFilterType == FW) && (UseForFw)) {
        if (!g_fallback_to_tc) sys->tcFilter(TC_DEL, DevId, HandleFw, 0, TcFilterType, NULL);
        else sys->tc(TcFilterDel);
    }
    else return 0;

    return 1;
}

/*
 *      This program is free software; you can redistribute it and/or
 *      modify it under the terms of the GNU General Public License
 *      as published by the Free Software Foundation; either version
 *      2 of the License, or (at your option) any later version.
 *
 * Authors: Alexey Kuznetsov, <kuznet@ms2.inr.ac.ru>
 *          Mariusz Jedwabny, <mariusz@jedwabny.net> - adapted to NiceShaper
 */

int TcFilter::parseIpAddr(struct tc_u32_sel *sel, int off, const char *param1, string param2)
{
    inet_prefix addr;
    __u32 mask;
    int offmask = 0;
    int mask_len = 32;

    if (!test->solidIpMask(string(param2))) {
        log.error( 37, Match );
        return -1;
    }

    if(getAddr1(&addr, param1) == -1) return -1;
    addr.bitlen = mask_len;

    addr.flags |= PREFIXLEN_SPECIFIED;
    if ((addr.bitlen = dot_to_bit(param2)) > mask_len) return -1;

    mask = 0;
    if (addr.bitlen) mask = htonl(0xFFFFFFFF<<(32-addr.bitlen));
    if (packKey(sel, addr.data[0], mask, off, offmask) == -1) return -1;

    return 0;
}

int TcFilter::parseU16(struct tc_u32_sel *sel, int off, int offmask, const char *param1, const char *param2)
{
    __u32 key;
    __u32 mask;

    if (getU32(&key, param1, 0)) return -1;

    if (getU32(&mask, param2, 16)) return -1;

    if (packKey16(sel, key, mask, off, offmask) == -1) return -1;

    return 0;
}

int TcFilter::parseU8(struct tc_u32_sel *sel, int off, int offmask, const char *param1, const char *param2)
{
    __u32 key;
    __u32 mask;

    if (getU32(&key, param1, 0)) return -1;

    if (getU32(&mask, param2, 16)) return -1;

    if (key > 0xFF || mask > 0xFF) return -1;

    if (packKey8(sel, key, mask, off, offmask) == -1) return -1;

    return 0;
}

/*
 * Adds one 32-bit key (value `key` under `mask`, at byte offset `off` with
 * `offmask`) to the u32 selector.
 *
 * If the selector already has a key with the same off/offmask, the new bits
 * are merged into it.  Otherwise a new key is appended.
 *
 * Returns 0 on success.  Returns -1 when the new key contradicts an existing
 * one on a bit both masks cover, when 128 keys are already present, or when
 * `off` is not a multiple of 4.
 *
 * NOTE(review): appending writes sel->keys[hwm], past the fixed part of
 * struct tc_u32_sel.  This assumes the object `sel` points into reserves room
 * for 128 keys after it (presumably TcU32Selector does); confirm in the header.
 */
int TcFilter::packKey(struct tc_u32_sel *sel, __u32 key, __u32 mask, int off, int offmask)
{
    int i;
    int hwm = sel->nkeys;

    // Drop value bits outside the mask so merges compare only relevant bits.
    key &= mask;

    for (i=0; i<hwm; i++) {
        if (sel->keys[i].off == off && sel->keys[i].offmask == offmask) {
            // Bits constrained by both the existing and the new key.
            __u32 intersect = mask&sel->keys[i].mask;

            // A shared bit required to have two different values: no packet
            // could match, so refuse the merge.
            if ((key^sel->keys[i].val) & intersect)
                return -1;
            sel->keys[i].val |= key;
            sel->keys[i].mask |= mask;
            return 0;
        }
    }

    // No key for this word yet: append one (limit 128, aligned offsets only).
    if (hwm >= 128)
        return -1;
    if (off % 4)
        return -1;
    sel->keys[hwm].val = key;
    sel->keys[hwm].mask = mask;
    sel->keys[hwm].off = off;
    sel->keys[hwm].offmask = offmask;
    sel->nkeys++;
    return 0;
}

int TcFilter::packKey16(struct tc_u32_sel *sel, __u32 key, __u32 mask, int off, int offmask)
{
    if (key > 0xFFFF || mask > 0xFFFF)
        return -1;

    if ((off & 3) == 0) {
        key <<= 16;
        mask <<= 16;
    }
    off &= ~3;
    key = htonl(key);
    mask = htonl(mask);

    return packKey(sel, key, mask, off, offmask);
}

int TcFilter::packKey8(struct tc_u32_sel *sel, __u32 key, __u32 mask, int off, int offmask)
{
    if (key > 0xFF || mask > 0xFF)
        return -1;

    if ((off & 3) == 0) {
        key <<= 24;
        mask <<= 24;
    } else if ((off & 3) == 1) {
        key <<= 16;
        mask <<= 16;
    } else if ((off & 3) == 2) {
        key <<= 8;
        mask <<= 8;
    }
    off &= ~3;
    key = htonl(key);
    mask = htonl(mask);

    return packKey(sel, key, mask, off, offmask);
}

/*
 * Parses `arg` as an unsigned number in `base` (0 = auto-detect prefix) and
 * stores it in *val.
 * Returns 0 on success.  Returns -1 for a NULL/empty string, trailing
 * characters, overflow, or a value that does not fit in 32 bits;
 * *val is untouched on failure.
 */
int TcFilter::getU32(__u32 *val, const char *arg, int base)
{
        unsigned long res;
        char *ptr;

        if (!arg || !*arg)
                return -1;

        // strtoul() only sets errno on failure, so clear it first to avoid a
        // stale ERANGE rejecting a valid ULONG_MAX.
        errno = 0;
        res = strtoul(arg, &ptr, base);

        // no digits at all, or trailing garbage
        if (!ptr || ptr == arg || *ptr)
                return -1;

        // Overflow.  Where unsigned long is 32 bits, an out-of-range input
        // saturates to ULONG_MAX == 0xFFFFFFFF and would pass the check below.
        if (res == ULONG_MAX && errno == ERANGE)
                return -1;

        // unsigned long may be wider than 32 bits
        if (res > 0xFFFFFFFFUL)
                return -1;

        *val = res;
        return 0;
}

/*
 * Parses the IPv4 address `name` into a zeroed inet_prefix (family AF_INET,
 * 4-byte length, bitlen -1 = no prefix length set here).
 * Returns 0 on success, -1 when the address cannot be parsed.
 */
int TcFilter::getAddr1(inet_prefix *addr, const char *name)
{
    // Zero first: getAddrIpv4() leaves bytes of short forms ("10.8") untouched.
    memset(addr, 0, sizeof(*addr));
    addr->family = AF_INET;

    const int parsed = getAddrIpv4(reinterpret_cast<__u8 *>(addr->data), name);
    if (parsed <= 0)
        return -1;

    addr->bytelen = 4;
    addr->bitlen = -1;
    return 0;
}

/* Parses a dotted IPv4 address into ap[0..3] without inet_aton()/inet_pton(),
 * to keep the legacy reading of short forms: "10.8" means 10.8.0.0, not
 * 10.0.0.8.  Each part may use any base strtoul() accepts and must be
 * 0..255.  Bytes after the last given part are left untouched.
 * Returns 1 on success, -1 on a malformed address.
 */
int TcFilter::getAddrIpv4(__u8 *ap, const char *cp)
{
    const char *cursor = cp;

    for (int octet = 0; octet < 4; ++octet) {
        char *stop;
        const unsigned long part = strtoul(cursor, &stop, 0);

        // no digits, or a value that does not fit in a byte
        if (stop == cursor || part > 255)
            return -1;

        ap[octet] = part;

        if (*stop == '\0')
            return 1;

        // parts are separated by '.', and there are at most four of them
        if (octet == 3 || *stop != '.')
            return -1;

        cursor = stop + 1;
    }

    return 1;
}

