diff graal/com.oracle.graal.nodes/src/com/oracle/graal/nodes/type/StampTool.java @ 11352:8185c119d731

"always set" bit mask on IntegerStamps
author Lukas Stadler <lukas.stadler@jku.at>
date Fri, 16 Aug 2013 13:15:42 +0200
parents ef6915cf1e59
children 90201030d3cf
line wrap: on
line diff
--- a/graal/com.oracle.graal.nodes/src/com/oracle/graal/nodes/type/StampTool.java	Fri Aug 16 12:09:36 2013 +0200
+++ b/graal/com.oracle.graal.nodes/src/com/oracle/graal/nodes/type/StampTool.java	Fri Aug 16 13:15:42 2013 +0200
@@ -29,21 +29,37 @@
 /**
  * Helper class that is used to keep all stamp-related operations in one place.
  */
-// TODO(ls) maybe move the contents into IntegerStamp
 public class StampTool {
 
     public static Stamp negate(Stamp stamp) {
+        if (true) {
+            return StampFactory.forKind(stamp.kind());
+        }
         Kind kind = stamp.kind();
         if (stamp instanceof IntegerStamp) {
             IntegerStamp integerStamp = (IntegerStamp) stamp;
             if (integerStamp.lowerBound() != kind.getMinValue()) {
                 // TODO(ls) check if the mask calculation is correct...
-                return new IntegerStamp(kind, -integerStamp.upperBound(), -integerStamp.lowerBound(), IntegerStamp.defaultMask(kind) & (integerStamp.mask() | -integerStamp.mask()));
+                return StampFactory.forInteger(kind, -integerStamp.upperBound(), -integerStamp.lowerBound());
             }
+        } else if (stamp instanceof FloatStamp) {
+            FloatStamp floatStamp = (FloatStamp) stamp;
+            return new FloatStamp(kind, -floatStamp.upperBound(), -floatStamp.lowerBound(), floatStamp.isNonNaN());
         }
+
         return StampFactory.forKind(kind);
     }
 
+    public static Stamp not(Stamp stamp) {
+        if (stamp instanceof IntegerStamp) {
+            IntegerStamp integerStamp = (IntegerStamp) stamp;
+            assert stamp.kind() == Kind.Int || stamp.kind() == Kind.Long;
+            long defaultMask = IntegerStamp.defaultMask(stamp.kind());
+            return new IntegerStamp(stamp.kind(), ~integerStamp.upperBound(), ~integerStamp.lowerBound(), (~integerStamp.upMask()) & defaultMask, (~integerStamp.downMask()) & defaultMask);
+        }
+        return StampFactory.forKind(stamp.kind());
+    }
+
     public static Stamp meet(Collection<? extends StampProvider> values) {
         Iterator<? extends StampProvider> iterator = values.iterator();
         if (iterator.hasNext()) {
@@ -64,20 +80,8 @@
         return StampFactory.illegal();
     }
 
-    public static Stamp add(IntegerStamp stamp1, IntegerStamp stamp2) {
-        Kind kind = stamp1.kind();
-        assert kind == stamp2.kind();
-        if (addOverflow(stamp1.lowerBound(), stamp2.lowerBound(), kind)) {
-            return StampFactory.forKind(kind);
-        }
-        if (addOverflow(stamp1.upperBound(), stamp2.upperBound(), kind)) {
-            return StampFactory.forKind(kind);
-        }
-        long lowerBound = stamp1.lowerBound() + stamp2.lowerBound();
-        long upperBound = stamp1.upperBound() + stamp2.upperBound();
-        long mask = IntegerStamp.maskFor(kind, lowerBound, upperBound);
-
-        return StampFactory.forInteger(kind, lowerBound, upperBound, mask);
+    private static long carryBits(long x, long y) {
+        return (x + y) ^ x ^ y;
     }
 
     public static Stamp sub(Stamp stamp1, Stamp stamp2) {
@@ -92,16 +96,17 @@
     }
 
     public static Stamp div(IntegerStamp stamp1, IntegerStamp stamp2) {
+        assert stamp1.kind() == stamp2.kind();
         Kind kind = stamp1.kind();
         if (stamp2.isStrictlyPositive()) {
             long lowerBound = stamp1.lowerBound() / stamp2.lowerBound();
             long upperBound = stamp1.upperBound() / stamp2.lowerBound();
-            return StampFactory.forInteger(kind, lowerBound, upperBound, IntegerStamp.maskFor(kind, lowerBound, upperBound));
+            return StampFactory.forInteger(kind, lowerBound, upperBound);
         }
         return StampFactory.forKind(kind);
     }
 
-    private static boolean addOverflow(long x, long y, Kind kind) {
+    private static boolean addOverflows(long x, long y, Kind kind) {
         long result = x + y;
         if (kind == Kind.Long) {
             return ((x ^ result) & (y ^ result)) < 0;
@@ -111,29 +116,64 @@
         }
     }
 
-    private static final long INTEGER_SIGN_BIT = 0x80000000L;
-    private static final long LONG_SIGN_BIT = 0x8000000000000000L;
+    public static IntegerStamp add(IntegerStamp stamp1, IntegerStamp stamp2) {
+        try {
+            if (stamp1.isUnrestricted() || stamp2.isUnrestricted()) {
+                return (IntegerStamp) StampFactory.forKind(stamp1.kind());
+            }
+            Kind kind = stamp1.kind();
+            assert stamp1.kind() == stamp2.kind();
+            long defaultMask = IntegerStamp.defaultMask(kind);
+            long variableBits = (stamp1.downMask() ^ stamp1.upMask()) | (stamp2.downMask() ^ stamp2.upMask());
+            long variableBitsWithCarry = variableBits | (carryBits(stamp1.downMask(), stamp2.downMask()) ^ carryBits(stamp1.upMask(), stamp2.upMask()));
+            long newDownMask = (stamp1.downMask() + stamp2.downMask()) & ~variableBitsWithCarry;
+            long newUpMask = (stamp1.downMask() + stamp2.downMask()) | variableBitsWithCarry;
+
+            newDownMask &= defaultMask;
+            newUpMask &= defaultMask;
 
-    private static Stamp stampForMask(Kind kind, long mask) {
-        return stampForMask(kind, mask, 0);
+            long lowerBound;
+            long upperBound;
+            if (addOverflows(stamp1.lowerBound(), stamp2.lowerBound(), kind) || addOverflows(stamp1.upperBound(), stamp2.upperBound(), kind)) {
+                lowerBound = kind.getMinValue();
+                upperBound = kind.getMaxValue();
+            } else {
+                lowerBound = stamp1.lowerBound() + stamp2.lowerBound();
+                upperBound = stamp1.upperBound() + stamp2.upperBound();
+            }
+            IntegerStamp limit = StampFactory.forInteger(kind, lowerBound, upperBound);
+            newUpMask &= limit.upMask();
+            return new IntegerStamp(kind, lowerBound | newDownMask, signExtend(upperBound & newUpMask, kind), newDownMask, newUpMask);
+        } catch (Throwable e) {
+            throw new RuntimeException(stamp1 + " + " + stamp2, e);
+        }
     }
 
-    private static Stamp stampForMask(Kind kind, long mask, long alwaysSetBits) {
+    public static Stamp sub(IntegerStamp stamp1, IntegerStamp stamp2) {
+        if (stamp1.isUnrestricted() || stamp2.isUnrestricted()) {
+            return StampFactory.forKind(stamp1.kind());
+        }
+        return add(stamp1, (IntegerStamp) StampTool.negate(stamp2));
+    }
+
+    private static Stamp stampForMask(Kind kind, long downMask, long upMask) {
         long lowerBound;
         long upperBound;
-        if (kind == Kind.Int && (mask & INTEGER_SIGN_BIT) != 0) {
-            // the mask is negative
-            lowerBound = Integer.MIN_VALUE;
-            upperBound = mask ^ INTEGER_SIGN_BIT;
-        } else if (kind == Kind.Long && (mask & LONG_SIGN_BIT) != 0) {
-            // the mask is negative
-            lowerBound = Long.MIN_VALUE;
-            upperBound = mask ^ LONG_SIGN_BIT;
+        if (((upMask >>> (kind.getBitCount() - 1)) & 1) == 0) {
+            lowerBound = downMask;
+            upperBound = upMask;
+        } else if (((downMask >>> (kind.getBitCount() - 1)) & 1) == 1) {
+            lowerBound = downMask;
+            upperBound = upMask;
         } else {
-            lowerBound = alwaysSetBits;
-            upperBound = mask;
+            lowerBound = upMask;
+            upperBound = kind.getMaxValue() & upMask;
         }
-        return StampFactory.forInteger(kind, lowerBound, upperBound, mask);
+        if (kind == Kind.Int) {
+            return StampFactory.forInteger(kind, (int) lowerBound, (int) upperBound, downMask, upMask);
+        } else {
+            return StampFactory.forInteger(kind, lowerBound, upperBound, downMask, upMask);
+        }
     }
 
     public static Stamp and(Stamp stamp1, Stamp stamp2) {
@@ -144,9 +184,8 @@
     }
 
     public static Stamp and(IntegerStamp stamp1, IntegerStamp stamp2) {
-        Kind kind = stamp1.kind();
-        long mask = stamp1.mask() & stamp2.mask();
-        return stampForMask(kind, mask);
+        assert stamp1.kind() == stamp2.kind();
+        return stampForMask(stamp1.kind(), stamp1.downMask() & stamp2.downMask(), stamp1.upMask() & stamp2.upMask());
     }
 
     public static Stamp or(Stamp stamp1, Stamp stamp2) {
@@ -157,13 +196,8 @@
     }
 
     public static Stamp or(IntegerStamp stamp1, IntegerStamp stamp2) {
-        Kind kind = stamp1.kind();
-        long mask = stamp1.mask() | stamp2.mask();
-        if (stamp1.lowerBound() >= 0 && stamp2.lowerBound() >= 0) {
-            return stampForMask(kind, mask, stamp1.lowerBound() | stamp2.lowerBound());
-        } else {
-            return stampForMask(kind, mask);
-        }
+        assert stamp1.kind() == stamp2.kind();
+        return stampForMask(stamp1.kind(), stamp1.downMask() | stamp2.downMask(), stamp1.upMask() | stamp2.upMask());
     }
 
     public static Stamp xor(Stamp stamp1, Stamp stamp2) {
@@ -174,9 +208,11 @@
     }
 
     public static Stamp xor(IntegerStamp stamp1, IntegerStamp stamp2) {
-        Kind kind = stamp1.kind();
-        long mask = stamp1.mask() | stamp2.mask();
-        return stampForMask(kind, mask);
+        assert stamp1.kind() == stamp2.kind();
+        long variableBits = (stamp1.downMask() ^ stamp1.upMask()) | (stamp2.downMask() ^ stamp2.upMask());
+        long newDownMask = (stamp1.downMask() ^ stamp2.downMask()) & ~variableBits;
+        long newUpMask = (stamp1.downMask() ^ stamp2.downMask()) | variableBits;
+        return stampForMask(stamp1.kind(), newDownMask, newUpMask);
     }
 
     public static Stamp unsignedRightShift(Stamp value, Stamp shift) {
@@ -201,12 +237,11 @@
                     lowerBound = value.lowerBound() >>> shiftCount;
                     upperBound = value.upperBound() >>> shiftCount;
                 }
-                long mask = value.mask() >>> shiftCount;
-                return StampFactory.forInteger(kind, lowerBound, upperBound, mask);
+                return new IntegerStamp(kind, lowerBound, upperBound, value.downMask() >>> shiftCount, value.upMask() >>> shiftCount);
             }
         }
-        long mask = IntegerStamp.maskFor(kind, value.lowerBound(), value.upperBound());
-        return stampForMask(kind, mask);
+        long mask = IntegerStamp.upMaskFor(kind, value.lowerBound(), value.upperBound());
+        return stampForMask(kind, 0, mask);
     }
 
     public static Stamp leftShift(Stamp value, Stamp shift) {
@@ -218,55 +253,78 @@
 
     public static Stamp leftShift(IntegerStamp value, IntegerStamp shift) {
         Kind kind = value.kind();
+        long defaultMask = IntegerStamp.defaultMask(kind);
+        if (value.upMask() == 0) {
+            return value;
+        }
         int shiftBits = kind == Kind.Int ? 5 : 6;
         long shiftMask = kind == Kind.Int ? 0x1FL : 0x3FL;
         if ((shift.lowerBound() >>> shiftBits) == (shift.upperBound() >>> shiftBits)) {
-            long mask = 0;
-            for (long i = shift.lowerBound() & shiftMask; i <= (shift.upperBound() & shiftMask); i++) {
-                mask |= value.mask() << i;
+            long downMask = defaultMask;
+            long upMask = 0;
+            for (long i = shift.lowerBound(); i <= shift.upperBound(); i++) {
+                if (shift.contains(i)) {
+                    downMask &= value.downMask() << (i & shiftMask);
+                    upMask |= value.upMask() << (i & shiftMask);
+                }
             }
-            mask &= IntegerStamp.defaultMask(kind);
-            return stampForMask(kind, mask);
+            Stamp result = stampForMask(kind, downMask, upMask & IntegerStamp.defaultMask(kind));
+            return result;
         }
         return StampFactory.forKind(kind);
     }
 
     public static Stamp intToLong(IntegerStamp intStamp) {
-        long mask;
-        if (intStamp.isPositive()) {
-            mask = intStamp.mask();
-        } else {
-            mask = 0xffffffff00000000L | intStamp.mask();
-        }
-        return StampFactory.forInteger(Kind.Long, intStamp.lowerBound(), intStamp.upperBound(), mask);
+        return StampFactory.forInteger(Kind.Long, intStamp.lowerBound(), intStamp.upperBound(), signExtend(intStamp.downMask(), Kind.Int), signExtend(intStamp.upMask(), Kind.Int));
     }
 
-    private static Stamp narrowingKindConvertion(IntegerStamp fromStamp, Kind toKind) {
-        long mask = fromStamp.mask() & IntegerStamp.defaultMask(toKind);
-        long lowerBound = saturate(fromStamp.lowerBound(), toKind);
-        long upperBound = saturate(fromStamp.upperBound(), toKind);
+    private static IntegerStamp narrowingKindConvertion(IntegerStamp fromStamp, Kind toKind) {
+        assert toKind == Kind.Byte || toKind == Kind.Char || toKind == Kind.Short || toKind == Kind.Int;
+        final long upperBound;
         if (fromStamp.lowerBound() < toKind.getMinValue()) {
             upperBound = toKind.getMaxValue();
+        } else {
+            upperBound = saturate(fromStamp.upperBound(), toKind);
         }
+        final long lowerBound;
         if (fromStamp.upperBound() > toKind.getMaxValue()) {
             lowerBound = toKind.getMinValue();
+        } else {
+            lowerBound = saturate(fromStamp.lowerBound(), toKind);
         }
-        return StampFactory.forInteger(toKind.getStackKind(), lowerBound, upperBound, mask);
+
+        long defaultMask = IntegerStamp.defaultMask(toKind);
+        long intMask = IntegerStamp.defaultMask(Kind.Int);
+        long newUpMask = signExtend(fromStamp.upMask() & defaultMask, toKind) & intMask;
+        long newDownMask = signExtend(fromStamp.downMask() & defaultMask, toKind) & intMask;
+        return new IntegerStamp(toKind.getStackKind(), (int) ((lowerBound | newDownMask) & newUpMask), (int) ((upperBound | newDownMask) & newUpMask), newDownMask, newUpMask);
     }
 
-    public static Stamp intToByte(IntegerStamp intStamp) {
+    private static long signExtend(long value, Kind valueKind) {
+        if (valueKind != Kind.Char && (value >>> (valueKind.getBitCount() - 1) & 1) == 1) {
+            return value | (-1L << valueKind.getBitCount());
+        } else {
+            return value;
+        }
+    }
+
+    public static IntegerStamp intToByte(IntegerStamp intStamp) {
+        assert intStamp.kind() == Kind.Int;
         return narrowingKindConvertion(intStamp, Kind.Byte);
     }
 
-    public static Stamp intToShort(IntegerStamp intStamp) {
+    public static IntegerStamp intToShort(IntegerStamp intStamp) {
+        assert intStamp.kind() == Kind.Int;
         return narrowingKindConvertion(intStamp, Kind.Short);
     }
 
-    public static Stamp intToChar(IntegerStamp intStamp) {
+    public static IntegerStamp intToChar(IntegerStamp intStamp) {
+        assert intStamp.kind() == Kind.Int;
         return narrowingKindConvertion(intStamp, Kind.Char);
     }
 
-    public static Stamp longToInt(IntegerStamp longStamp) {
+    public static IntegerStamp longToInt(IntegerStamp longStamp) {
+        assert longStamp.kind() == Kind.Long;
         return narrowingKindConvertion(longStamp, Kind.Int);
     }