当前位置: 代码网 > it编程>编程语言>Java > (AtCoder Beginner Contest 331) --- F - Palindrome Query -- -题解 (线段树 + 哈希)

(AtCoder Beginner Contest 331) --- F - Palindrome Query -- -题解 (线段树 + 哈希)

2024年07月31日 Java 我要评论
那么题目要求对于某个点进行修改,然后查询这个区间是否是线段树,用到上一个性质,那我们可以发现这道题就是一个典型的单点修改+区间查询。如果 a1a2a3 是一个回文串,一定有 a1 * p * p + a2 * p + a3 == a1 + a2 * p + a3 * p * p,对于任意长度的回文串这个性质都是显然的。例如: ab + cd 这两个区间, ab的正哈希 为 a*p+b, cd的正哈希为c*p+d. abcd的正哈希为 a*p*p*p+ b*p*p + c*p+d,这样就可以看出规律了。

f - palindrome query

        题目大意:

                

 

思路解析:

        如果 a1a2a3 是一个回文串,一定有 a1 * p * p + a2 * p + a3 == a1 + a2 * p + a3 * p * p,对于任意长度的回文串这个性质都是显然的。

        那么题目要求对于某个点进行修改,然后查询这个区间是否是线段树,用到上一个性质,那我们可以发现这道题就是一个典型的单点修改+区间查询。 利用线段树维护 a-b这个区间的正反哈希,但是修改后,如何合并两个区间,

例如: ab  + cd 这两个区间, ab的正哈希 为 a*p+b, cd的正哈希为c*p+d. abcd的正哈希为 a*p*p*p+ b*p*p + c*p+d,这样就可以看出规律了。

 代码实现:

        

import java.io.*;
import java.util.*;


public class main {
    static long mod = (int) 1e9 + 7;
    static int base = 131;
    static long[] pow = new long[1000005];
    static char[] s;
    static int maxn = 1000005;


    public static void main(string[] args) throws ioexception {
        fastscanner f = new fastscanner();
        printwriter w = new printwriter(system.out);
        int n = f.nextint();
        int q = f.nextint();
        pow[0] = 1;
        for (int i = 1; i <= 1000000; i++) {
            pow[i] = pow[i - 1] * base % mod;
        }
        char[] a = f.nextstring().tochararray();
        s = new char[n+1];
        for (int i = 1; i <= n; i++) {
            s[i] = a[i-1];
        }
        segtree seg = new segtree();
        seg.build(1, 1, n);

        for (int i = 0; i < q; i++) {
            int op = f.nextint();
            if (op == 1){
                int x = f.nextint();
                char c = f.nextchar();
                seg.change(1, x, x, c);
            }else {
                int l = f.nextint();
                int r = f.nextint();
                long[] aa = seg.sum(1, l, r);
                if (aa[0] == aa[1]) w.println("yes");
                else w.println("no");
            }
        }
        w.flush();
        w.close();

    }

    public static class node {
        int l, r;
        long pre, suf;
    }

    public static class segtree {
        node[] t = new node[4 * maxn];

        public segtree() {
            for (int i = 0; i < 4 * maxn; i++) {
                t[i] = new node();
            }
        }

        public void build(int root, int l, int r) {
            t[root].l = l;
            t[root].r = r;
            if (l == r) {
                t[root].pre = s[l];
                t[root].suf = s[l];
                return;
            }
            int ch = root << 1;
            int mid = (l + r) >> 1;
            build(ch, l, mid);
            build(ch + 1, mid + 1, r);
            update(root);
        }

        public void update(int root) {
            int ch = root << 1;
            t[root].pre = (t[ch].pre * pow[t[ch + 1].r - t[ch + 1].l + 1] % mod + t[ch + 1].pre) % mod; // 正哈希
            t[root].suf = (t[ch + 1].suf * pow[t[ch].r - t[ch].l + 1] % mod + t[ch].suf) % mod; // 反哈希
        }

        public void change(int root, int l, int r, char x) {
            if (t[root].l == l && t[root].r == r) {
                t[root].suf = x ;
                t[root].pre = x;
                return;
            }
            int ch = root << 1;
            int mid = (t[root].l + t[root].r) >> 1;
            if (r <= mid) change(ch, l, r, x);
            else change(ch + 1, l, r, x);
            update(root);
        }

        public long[] sum(int root, int l, int r) {
            if (t[root].l == l && t[root].r == r) {
                return new long[]{
                        t[root].pre, t[root].suf
                };
            }
            int ch = root << 1;
            int mid = (t[root].l + t[root].r) >> 1;
            if (r <= mid) return sum(ch,l,r);
            else if (l > mid) return sum(ch+1,l,r);
            else {
                long[] a = sum(ch,l,mid);
                long[] b = sum(ch+1,mid+1,r);
                return new long[]{ (a[0] * pow[r - mid] % mod + b[0]) % mod, (a[1] + b[1] * pow[mid - l + 1] % mod) % mod};
            }
        }
    }

    private static class fastscanner {
        final private int buffer_size = 1 << 16;
        private datainputstream din;
        private byte[] buffer;
        private int bufferpointer, bytesread;

        private fastscanner() throws ioexception {
            din = new datainputstream(system.in);
            buffer = new byte[buffer_size];
            bufferpointer = bytesread = 0;
        }

        private short nextshort() throws ioexception {
            short ret = 0;
            byte c = read();
            while (c <= ' ') c = read();
            boolean neg = (c == '-');
            if (neg) c = read();
            do ret = (short) (ret * 10 + c - '0');
            while ((c = read()) >= '0' && c <= '9');
            if (neg) return (short) -ret;
            return ret;
        }

        private int nextint() throws ioexception {
            int ret = 0;
            byte c = read();
            while (c <= ' ') c = read();
            boolean neg = (c == '-');
            if (neg) c = read();
            do ret = ret * 10 + c - '0';
            while ((c = read()) >= '0' && c <= '9');
            if (neg) return -ret;
            return ret;
        }

        public long nextlong() throws ioexception {
            long ret = 0;
            byte c = read();
            while (c <= ' ') c = read();
            boolean neg = (c == '-');
            if (neg) c = read();
            do ret = ret * 10 + c - '0';
            while ((c = read()) >= '0' && c <= '9');
            if (neg) return -ret;
            return ret;
        }

        private char nextchar() throws ioexception {
            byte c = read();
            while (c <= ' ') c = read();
            return (char) c;
        }

        private string nextstring() throws ioexception {
            stringbuilder ret = new stringbuilder();
            byte c = read();
            while (c <= ' ') c = read();
            do {
                ret.append((char) c);
            } while ((c = read()) > ' ');
            return ret.tostring();
        }

        private void fillbuffer() throws ioexception {
            bytesread = din.read(buffer, bufferpointer = 0, buffer_size);
            if (bytesread == -1) buffer[0] = -1;
        }

        private byte read() throws ioexception {
            if (bufferpointer == bytesread) fillbuffer();
            return buffer[bufferpointer++];
        }
    }
}

(0)

相关文章:

版权声明:本文内容由互联网用户贡献,该文观点仅代表作者本人。本站仅提供信息存储服务,不拥有所有权,不承担相关法律责任。 如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 2386932994@qq.com 举报,一经查实将立刻删除。

发表评论

验证码:
Copyright © 2017-2025  代码网 保留所有权利. 粤ICP备2024248653号
站长QQ:2386932994 | 联系邮箱:2386932994@qq.com