From aae2eb120f5aa900e553ff39f1e2ffc8274cb747 Mon Sep 17 00:00:00 2001 From: Kyle Rudy Date: Thu, 23 Mar 2017 17:39:13 -0700 Subject: [PATCH] Quick fix to enable nested collections like map>>

Previously putNative and getNative rejected list, set and map values
with "collection type" errors. As a result, a collection could not
appear as an element, key or value of another collection.

putNative (CqlList / CqlSet / CqlMap) now writes a 4-byte element count
(encodeInt). It then writes each element length-prefixed with
`toBytes 4`; for maps, each key is followed by its value.

getNative (ListColumn / SetColumn / MapColumn) reads the same layout
back: decodeInt for the count, then `withBytes 4` per element.

Both sides recurse into the elements with
`v' = if v >= V3 then v else V3`. So nested contents always use the
int-count / int-length layout, even when the outer version is below V3.

getList no longer wraps its argument in `withBytes 4`. The getValue
list/set/map cases now do that wrapping themselves. This lets
getNative reuse getList on an element body that withBytes has already
unwrapped.

withBytes failures now include the byte count that was read, in the
form "withBytes(<n>): <error>".

NOTE(review): any caller of getList outside these hunks that relied on
the old implicit `withBytes 4` would now see the length prefix as the
start of its data. No such caller is visible in this patch; confirm.

NOTE(review): forcing V3 layout for nested contents under older
protocol versions presumably matches the server's encoding of nested
collections. Confirm against the native protocol spec.

NOTE(review): `if v >= V3 then v else V3` is equivalent to `max V3 v`.
Version already has an Ord instance, since `>=` is used on it.

NOTE(review): CqlMaybe / CqlTuple (MaybeColumn / TupleColumn) are
still rejected by putNative / getNative.

--- src/Database/CQL/Protocol/Codec.hs | 62 ++++++++++++++++++++---------- 1 file changed, 42 insertions(+), 20 deletions(-) diff --git a/src/Database/CQL/Protocol/Codec.hs b/src/Database/CQL/Protocol/Codec.hs index 93b2fe8..c040665 100644 --- a/src/Database/CQL/Protocol/Codec.hs +++ b/src/Database/CQL/Protocol/Codec.hs @@ -468,9 +468,18 @@ putNative V4 (CqlTinyInt x) = put x putNative _ v@(CqlTinyInt _) = fail $ "putNative: tinyint: " ++ show v putNative V2 v@(CqlUdt _) = fail $ "putNative: udt: " ++ show v putNative v (CqlUdt x) = putByteString $ runPut (mapM_ (putValue v . snd) x) -putNative _ v@(CqlList _) = fail $ "putNative: collection type: " ++ show v -putNative _ v@(CqlSet _) = fail $ "putNative: collection type: " ++ show v -putNative _ v@(CqlMap _) = fail $ "putNative: collection type: " ++ show v +putNative v (CqlList x) = do + let v' = if (v >= V3) then v else V3 + encodeInt (fromIntegral (length x)) + mapM_ (toBytes 4 . putNative v') x +putNative v (CqlSet x) = do + let v' = if (v >= V3) then v else V3 + encodeInt (fromIntegral (length x)) + mapM_ (toBytes 4 . putNative v') x +putNative v (CqlMap x) = do + let v' = if (v >= V3) then v else V3 + encodeInt (fromIntegral (length x)) + forM_ x $ \(k, w) -> toBytes 4 (putNative v' k) >> toBytes 4 (putNative v' w) putNative _ v@(CqlMaybe _) = fail $ "putNative: collection type: " ++ show v putNative _ v@(CqlTuple _) = fail $ "putNative: tuple type: " ++ show v @@ -478,28 +487,28 @@ putNative _ v@(CqlTuple _) = fail $ "putNative: tuple type: " ++ show v -- Note: Empty lists, maps and sets are represented as null in cassandra. 
getValue :: Version -> ColumnType -> Get Value getValue v (ListColumn t) - | v >= V3 = CqlList <$> getList (do + | v >= V3 = CqlList <$> getList (withBytes 4 (do len <- decodeInt - replicateM (fromIntegral len) (withBytes 4 (getNative v t))) - | otherwise = CqlList <$> getList (do + replicateM (fromIntegral len) (withBytes 4 (getNative v t)))) + | otherwise = CqlList <$> getList (withBytes 4 (do len <- decodeShort - replicateM (fromIntegral len) (withBytes 2 (getNative v t))) + replicateM (fromIntegral len) (withBytes 2 (getNative v t)))) getValue v (SetColumn t) - | v >= V3 = CqlSet <$> getList (do + | v >= V3 = CqlSet <$> getList (withBytes 4 (do len <- decodeInt - replicateM (fromIntegral len) (withBytes 4 (getNative v t))) - | otherwise = CqlSet <$> getList (do + replicateM (fromIntegral len) (withBytes 4 (getNative v t)))) + | otherwise = CqlSet <$> getList (withBytes 4 (do len <- decodeShort - replicateM (fromIntegral len) (withBytes 2 (getNative v t))) + replicateM (fromIntegral len) (withBytes 2 (getNative v t)))) getValue v (MapColumn t u) - | v >= V3 = CqlMap <$> getList (do + | v >= V3 = CqlMap <$> getList (withBytes 4 (do len <- decodeInt replicateM (fromIntegral len) - ((,) <$> withBytes 4 (getNative v t) <*> withBytes 4 (getNative v u))) - | otherwise = CqlMap <$> getList (do + ((,) <$> withBytes 4 (getNative v t) <*> withBytes 4 (getNative v u)))) + | otherwise = CqlMap <$> getList (withBytes 4 (do len <- decodeShort replicateM (fromIntegral len) - ((,) <$> withBytes 2 (getNative v t) <*> withBytes 2 (getNative v u))) + ((,) <$> withBytes 2 (getNative v t) <*> withBytes 2 (getNative v u)))) getValue v (TupleColumn t) | v >= V3 = do b <- withBytes 4 remainingBytes either fail return $ flip runGet b $ CqlTuple <$> mapM (getValue v) t @@ -553,9 +562,22 @@ getNative v c@(UdtColumn _ x) let (n, t) = unzip x zip n <$> mapM (getValue v) t | otherwise = fail $ "getNative: udt: " ++ show c -getNative _ c@(ListColumn _) = fail $ "getNative: collection type: " 
++ show c -getNative _ c@(SetColumn _) = fail $ "getNative: collection type: " ++ show c -getNative _ c@(MapColumn _ _) = fail $ "getNative: collection type: " ++ show c +getNative v (ListColumn t) = do + let v' = if v >= V3 then v else V3 + CqlList <$> getList (do + len <- decodeInt + replicateM (fromIntegral len) (withBytes 4 (getNative v' t))) +getNative v (SetColumn t) = do + let v' = if v >= V3 then v else V3 + CqlSet <$> getList (do + len <- decodeInt + replicateM (fromIntegral len) (withBytes 4 (getNative v' t))) +getNative v (MapColumn t u) = do + let v' = if v >= V3 then v else V3 + CqlMap <$> getList (do + len <- decodeInt + replicateM (fromIntegral len) + ((,) <$> withBytes 4 (getNative v' t) <*> withBytes 4 (getNative v' u))) getNative _ c@(MaybeColumn _) = fail $ "getNative: collection type: " ++ show c getNative _ c@(TupleColumn _) = fail $ "getNative: tuple type: " ++ show c @@ -563,7 +585,7 @@ getList :: Get [a] -> Get [a] getList m = do n <- lookAhead (get :: Get Int32) if n < 0 then uncheckedSkip 4 >> return [] - else withBytes 4 m + else m withBytes :: Int -> Get a -> Get a withBytes s p = do @@ -575,7 +597,7 @@ withBytes s p = do fail "withBytes: null" b <- getBytes n case runGet p b of - Left e -> fail $ "withBytes: " ++ e + Left e -> fail $ "withBytes(" ++ (show n) ++ "): " ++ e Right x -> return x remainingBytes :: Get ByteString -- GitLab